1.modify wav2lip api; 2.add 512*512 params of fom (PaddlePaddle#338)
lzzyzlbb authored Jun 8, 2021
1 parent 1f335bb commit 170a2b0
Showing 5 changed files with 89 additions and 47 deletions.
9 changes: 7 additions & 2 deletions applications/tools/first-order-demo.py
@@ -68,7 +68,11 @@
                     action="store_true",
                     default=False,
                     help="whether there is only one person in the image or not")
-
+parser.add_argument("--image_size",
+                    dest="image_size",
+                    type=int,
+                    default=256,
+                    help="size of image")
 parser.set_defaults(relative=False)
 parser.set_defaults(adapt_scale=False)
 
@@ -87,5 +91,6 @@
                                    best_frame=args.best_frame,
                                    ratio=args.ratio,
                                    face_detector=args.face_detector,
-                                   multi_person=args.multi_person)
+                                   multi_person=args.multi_person,
+                                   image_size=args.image_size)
     predictor.run(args.source_image, args.driving_video)
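For context, the new --image_size option behaves like any other argparse integer flag; a minimal, self-contained sketch (the hard-coded argv is purely illustrative):

```python
import argparse

# Mirrors the flag added above; parses a hard-coded argv for illustration.
parser = argparse.ArgumentParser()
parser.add_argument("--image_size",
                    dest="image_size",
                    type=int,
                    default=256,
                    help="size of image")
args = parser.parse_args(["--image_size", "512"])
print(args.image_size)  # -> 512
```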
15 changes: 13 additions & 2 deletions applications/tools/wav2lip.py
@@ -109,5 +109,16 @@
     if args.cpu:
         paddle.set_device('cpu')
 
-    predictor = Wav2LipPredictor(args)
-    predictor.run()
+    predictor = Wav2LipPredictor(checkpoint_path = args.checkpoint_path,
+                                 static = args.static,
+                                 fps = args.fps,
+                                 pads = args.pads,
+                                 face_det_batch_size = args.face_det_batch_size,
+                                 wav2lip_batch_size = args.wav2lip_batch_size,
+                                 resize_factor = args.resize_factor,
+                                 crop = args.crop,
+                                 box = args.box,
+                                 rotate = args.rotate,
+                                 nosmooth = args.nosmooth,
+                                 face_detector = args.face_detector)
+    predictor.run(args.face, args.audio, args.outfile)
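The tool now forwards each CLI option to the predictor explicitly instead of passing the whole argparse namespace. A minimal sketch of the equivalent programmatic call, using only the keyword arguments and defaults visible in this commit (the input and output paths are hypothetical):

```python
from ppgan.apps import Wav2LipPredictor

# Every keyword below has the default shown in wav2lip_predictor.py,
# so any subset may be omitted.
predictor = Wav2LipPredictor(checkpoint_path=None,  # None -> pretrained weights are downloaded
                             static=False,
                             fps=25,
                             pads=[0, 10, 0, 0],
                             face_det_batch_size=16,
                             wav2lip_batch_size=128,
                             resize_factor=1,
                             crop=[0, -1, 0, -1],
                             box=[-1, -1, -1, -1],
                             rotate=False,
                             nosmooth=False,
                             face_detector='sfd')
predictor.run('input_video.mp4', 'speech.wav', 'result.mp4')  # hypothetical paths
```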
6 changes: 3 additions & 3 deletions docs/zh_CN/apis/apps.md
@@ -439,7 +439,7 @@ ppgan.apps.MiDaSPredictor(output=None, weight_path=None)
 
 ## ppgan.apps.Wav2lipPredictor
 
 ```python
-ppgan.apps.FirstOrderPredictor(args)
+ppgan.apps.Wav2LipPredictor()
 ```
 > Builds an instance of the Wav2lip model, which performs lip-sync synthesis: given a video of a person and an audio clip, it synchronizes the person's mouth movements with the input speech. The paper is A Lip Sync Expert Is All You Need for Speech to Lip Generation In the Wild; paper link: http://arxiv.org/abs/2008.10010.
@@ -449,8 +449,8 @@ ppgan.apps.FirstOrderPredictor(args)
 > ```
 > from ppgan.apps import Wav2LipPredictor
 > # The args parameter should be specified by argparse
-> predictor = Wav2LipPredictor(args)
-> predictor.run()
+> predictor = Wav2LipPredictor()
+> predictor.run(face, audio, outfile)
 > ```
 > **Parameters:**
14 changes: 9 additions & 5 deletions ppgan/apps/first_order_predictor.py
@@ -33,7 +33,6 @@
 
 from .base_predictor import BasePredictor
 
-IMAGE_SIZE = 256
 
 class FirstOrderPredictor(BasePredictor):
     def __init__(self,
@@ -47,7 +46,8 @@ def __init__(self,
                  ratio=1.0,
                  filename='result.mp4',
                  face_detector='sfd',
-                 multi_person=False):
+                 multi_person=False,
+                 image_size=256):
         if config is not None and isinstance(config, str):
             with open(config) as f:
                 self.cfg = yaml.load(f, Loader=yaml.SafeLoader)
@@ -85,8 +85,12 @@ def __init__(self,
                 }
             }
         }
+        self.image_size = image_size
         if weight_path is None:
-            vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
+            if self.image_size == 512:
+                vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk-512.pdparams'
+            else:
+                vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk.pdparams'
             weight_path = get_path_from_url(vox_cpk_weight_url)
 
         self.weight_path = weight_path
@@ -161,7 +165,7 @@ def get_prediction(face_image):
         reader.close()
 
         driving_video = [
-            cv2.resize(frame, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0 for frame in driving_video
+            cv2.resize(frame, (self.image_size, self.image_size)) / 255.0 for frame in driving_video
         ]
         results = []
 
@@ -171,7 +175,7 @@ def get_prediction(face_image):
             # for multi person
             for rec in bboxes:
                 face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
-                face_image = cv2.resize(face_image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
+                face_image = cv2.resize(face_image, (self.image_size, self.image_size)) / 255.0
                 predictions = get_prediction(face_image)
                 results.append({'rec': rec, 'predict': predictions})
             if len(bboxes) == 1 or not self.multi_person:
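Exposing image_size on the constructor makes the 512x512 first-order model a one-argument change; a hedged sketch using only parameters visible in this diff (the source and driving paths are placeholders):

```python
from ppgan.apps import FirstOrderPredictor

# image_size=512 makes __init__ select the vox-cpk-512.pdparams checkpoint;
# any other value keeps the original 256-px vox-cpk.pdparams weights.
predictor = FirstOrderPredictor(ratio=1.0,
                                filename='result.mp4',
                                face_detector='sfd',
                                multi_person=False,
                                image_size=512)
predictor.run('source.png', 'driving.mp4')  # placeholder paths
```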
92 changes: 57 additions & 35 deletions ppgan/apps/wav2lip_predictor.py
@@ -17,12 +17,31 @@
 
 
 class Wav2LipPredictor(BasePredictor):
-    def __init__(self, args):
-        self.args = args
-        if os.path.isfile(self.args.face) and path.basename(
-                self.args.face).split('.')[1] in ['jpg', 'png', 'jpeg']:
-            self.args.static = True
+    def __init__(self, checkpoint_path = None,
+                 static = False,
+                 fps = 25,
+                 pads = [0, 10, 0, 0],
+                 face_det_batch_size = 16,
+                 wav2lip_batch_size = 128,
+                 resize_factor = 1,
+                 crop = [0, -1, 0, -1],
+                 box = [-1, -1, -1, -1],
+                 rotate = False,
+                 nosmooth = False,
+                 face_detector = 'sfd'):
         self.img_size = 96
+        self.checkpoint_path = checkpoint_path
+        self.static = static
+        self.fps = fps
+        self.pads = pads
+        self.face_det_batch_size = face_det_batch_size
+        self.wav2lip_batch_size = wav2lip_batch_size
+        self.resize_factor = resize_factor
+        self.crop = crop
+        self.box = box
+        self.rotate = rotate
+        self.nosmooth = nosmooth
+        self.face_detector = face_detector
         makedirs('./temp', exist_ok=True)
 
     def get_smoothened_boxes(self, boxes, T):
@@ -38,9 +57,9 @@ def face_detect(self, images):
         detector = face_detection.FaceAlignment(
             face_detection.LandmarksType._2D,
             flip_input=False,
-            face_detector=self.args.face_detector)
+            face_detector=self.face_detector)
 
-        batch_size = self.args.face_det_batch_size
+        batch_size = self.face_det_batch_size
 
         while 1:
             predictions = []
@@ -61,7 +80,7 @@ def face_detect(self, images):
                 break
 
         results = []
-        pady1, pady2, padx1, padx2 = self.args.pads
+        pady1, pady2, padx1, padx2 = self.pads
         for rect, image in zip(predictions, images):
             if rect is None:
                 cv2.imwrite(
@@ -79,7 +98,7 @@ def face_detect(self, images):
             results.append([x1, y1, x2, y2])
 
         boxes = np.array(results)
-        if not self.args.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5)
+        if not self.nosmooth: boxes = self.get_smoothened_boxes(boxes, T=5)
         results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)]
                    for image, (x1, y1, x2, y2) in zip(images, boxes)]
 
@@ -89,21 +108,21 @@ def face_detect(self, images):
     def datagen(self, frames, mels):
         img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
 
-        if self.args.box[0] == -1:
-            if not self.args.static:
+        if self.box[0] == -1:
+            if not self.static:
                 face_det_results = self.face_detect(
                     frames)  # BGR2RGB for CNN face detection
             else:
                 face_det_results = self.face_detect([frames[0]])
         else:
             print(
                 'Using the specified bounding box instead of face detection...')
-            y1, y2, x1, x2 = self.args.box
+            y1, y2, x1, x2 = self.box
             face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)]
                                 for f in frames]
 
         for i, m in enumerate(mels):
-            idx = 0 if self.args.static else i % len(frames)
+            idx = 0 if self.static else i % len(frames)
             frame_to_save = frames[idx].copy()
             face, coords = face_det_results[idx].copy()
 
@@ -114,7 +133,7 @@ def datagen(self, frames, mels):
             frame_batch.append(frame_to_save)
             coords_batch.append(coords)
 
-            if len(img_batch) >= self.args.wav2lip_batch_size:
+            if len(img_batch) >= self.wav2lip_batch_size:
                 img_batch, mel_batch = np.asarray(img_batch), np.asarray(
                     mel_batch)
 
@@ -143,18 +162,22 @@ def datagen(self, frames, mels):
 
         yield img_batch, mel_batch, frame_batch, coords_batch
 
-    def run(self):
-        if not os.path.isfile(self.args.face):
+    def run(self, face, audio_seq, outfile):
+        if os.path.isfile(face) and path.basename(
+                face).split('.')[1] in ['jpg', 'png', 'jpeg']:
+            self.static = True
+
+        if not os.path.isfile(face):
             raise ValueError(
                 '--face argument must be a valid path to video/image file')
 
         elif path.basename(
-                self.args.face).split('.')[1] in ['jpg', 'png', 'jpeg']:
-            full_frames = [cv2.imread(self.args.face)]
-            fps = self.args.fps
+                face).split('.')[1] in ['jpg', 'png', 'jpeg']:
+            full_frames = [cv2.imread(face)]
+            fps = self.fps
 
         else:
-            video_stream = cv2.VideoCapture(self.args.face)
+            video_stream = cv2.VideoCapture(face)
             fps = video_stream.get(cv2.CAP_PROP_FPS)
 
             print('Reading video frames...')
@@ -165,15 +188,15 @@ def run(self):
                 if not still_reading:
                     video_stream.release()
                     break
-                if self.args.resize_factor > 1:
+                if self.resize_factor > 1:
                     frame = cv2.resize(
-                        frame, (frame.shape[1] // self.args.resize_factor,
-                                frame.shape[0] // self.args.resize_factor))
+                        frame, (frame.shape[1] // self.resize_factor,
+                                frame.shape[0] // self.resize_factor))
 
-                if self.args.rotate:
+                if self.rotate:
                     frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE)
 
-                y1, y2, x1, x2 = self.args.crop
+                y1, y2, x1, x2 = self.crop
                 if x2 == -1: x2 = frame.shape[1]
                 if y2 == -1: y2 = frame.shape[0]
 
@@ -184,18 +207,16 @@ def run(self):
         print("Number of frames available for inference: " +
               str(len(full_frames)))
 
-        if not self.args.audio.endswith('.wav'):
+        if not audio_seq.endswith('.wav'):
             print('Extracting raw audio...')
             command = 'ffmpeg -y -i {} -strict -2 {}'.format(
-                self.args.audio, 'temp/temp.wav')
+                audio_seq, 'temp/temp.wav')
 
             subprocess.call(command, shell=True)
-            self.args.audio = 'temp/temp.wav'
+            audio_seq = 'temp/temp.wav'
 
-        wav = audio.load_wav(self.args.audio, 16000)
+        wav = audio.load_wav(audio_seq, 16000)
         mel = audio.melspectrogram(wav)
-        print(mel.shape)
-
         if np.isnan(mel.reshape(-1)).sum() > 0:
             raise ValueError(
                 'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again'
@@ -216,15 +237,15 @@ def run(self):
 
         full_frames = full_frames[:len(mel_chunks)]
 
-        batch_size = self.args.wav2lip_batch_size
+        batch_size = self.wav2lip_batch_size
         gen = self.datagen(full_frames.copy(), mel_chunks)
 
         model = Wav2Lip()
-        if self.args.checkpoint_path is None:
+        if self.checkpoint_path is None:
             model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL)
             weights = paddle.load(model_weights_path)
         else:
-            weights = paddle.load(self.args.checkpoint_path)
+            weights = paddle.load(self.checkpoint_path)
         model.load_dict(weights)
         model.eval()
         print("Model loaded")
@@ -258,5 +279,6 @@ def run(self):
         out.release()
 
         command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(
-            self.args.audio, 'temp/result.avi', self.args.outfile)
+            audio_seq, 'temp/result.avi', outfile)
         subprocess.call(command, shell=platform.system() != 'Windows')
+
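The nosmooth flag above gates self.get_smoothened_boxes(boxes, T=5), whose body is collapsed out of this diff. For readers without the full file, a plausible reconstruction of that helper, assuming the simple moving-average smoothing used by the upstream Wav2Lip code (an assumption, not part of this commit):

```python
import numpy as np

def get_smoothened_boxes(boxes, T):
    """Replace each face box by the mean over a window of up to T detections
    (assumed behavior of the collapsed helper; reduces detector jitter)."""
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]  # clamp the window at the end
        else:
            window = boxes[i:i + T]
        boxes[i] = np.mean(window, axis=0)  # smooths in place
    return boxes

# Example: jittery box coordinates are averaged over a window of 3 frames.
boxes = np.array([[10., 20., 110., 120.],
                  [12., 22., 112., 122.],
                  [ 8., 18., 108., 118.],
                  [11., 21., 111., 121.]])
print(get_smoothened_boxes(boxes.copy(), T=3)[0])  # -> [ 10.  20. 110. 120.]
```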