Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote GPU #74

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
vox-adv-cpk.pth.tar
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
miniconda3-latest
28 changes: 28 additions & 0 deletions arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from argparse import ArgumentParser

# Shared CLI definition for the whole application; other modules do
# `from arguments import opt` to read the parsed options.
parser = ArgumentParser()
parser.add_argument("--config", required=True, help="path to config")
parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore")

# Animation behaviour.  NOTE: `store_true` flags already default to False,
# so no explicit set_defaults(...) calls are needed.
parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
parser.add_argument("--no-pad", dest="no_pad", action="store_true", help="don't pad output image")

# Camera devices / streaming.
parser.add_argument("--cam", type=int, default=0, help="Webcam device ID")
parser.add_argument("--virt-cam", type=int, default=0, help="Virtualcam device ID")
parser.add_argument("--no-stream", action="store_true", help="On Linux, force no streaming")

parser.add_argument("--verbose", action="store_true", help="Print additional information")

parser.add_argument("--avatars", default="./avatars", help="path to avatars directory")

# Remote-GPU worker mode: run as worker, or connect to one by host/port.
parser.add_argument("--is-worker", action="store_true", help="Whether to run this process as a remote GPU worker")
parser.add_argument("--worker-port", type=int, default=5556, help="Which port to run the worker on")
parser.add_argument("--worker-host", type=str, default=None, help="Hostname of the worker")
parser.add_argument("--compress", action="store_true", help="Whether to compress messages to worker")

# Parsed once at import time so every importing module sees one namespace.
opt = parser.parse_args()
162 changes: 35 additions & 127 deletions cam_fomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,18 @@
import glob
import yaml
import time
from argparse import ArgumentParser
import requests

import imageio
import numpy as np
from skimage.transform import resize
import cv2

import torch
from sync_batchnorm import DataParallelWithCallback

from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
from scipy.spatial import ConvexHull

import face_alignment

from videocaptureasync import VideoCaptureAsync
import predictor_local
import predictor_remote
from arguments import opt

from sys import platform as _platform
_streaming = False
Expand All @@ -29,68 +22,24 @@
_streaming = True


def load_checkpoints(config_path, checkpoint_path, device='cuda'):
    """Build the FOMM generator and keypoint detector and restore their weights.

    Parameters:
        config_path: path to the YAML model config (its 'model_params'
            section provides constructor arguments for both networks).
        checkpoint_path: path to a checkpoint containing 'generator' and
            'kp_detector' state dicts.
        device: torch device string the models are placed on.

    Returns:
        (generator, kp_detector), each wrapped in DataParallelWithCallback
        and switched to eval mode.
    """
    with open(config_path) as f:
        # yaml.load() without an explicit Loader is deprecated and unsafe
        # (it can construct arbitrary Python objects).  FullLoader keeps
        # the old behaviour for ordinary config files while refusing
        # arbitrary object construction.
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    generator.to(device)

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    kp_detector.to(device)

    # map_location lets a GPU-trained checkpoint load on CPU and vice versa.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    # Inference only: disable dropout / freeze batch-norm statistics.
    generator.eval()
    kp_detector.eval()

    return generator, kp_detector

def normalize_alignment_kp(kp):
    """Return landmarks centered on their mean, with the (x, y) columns
    rescaled so their convex hull has unit area.  The input array is not
    modified."""
    centered = kp - kp.mean(axis=0, keepdims=True)
    # For a 2-D point set, ConvexHull.volume is the enclosed area.
    hull_area = ConvexHull(centered[:, :2]).volume
    centered[:, :2] /= np.sqrt(hull_area)
    return centered

def get_frame_kp(fa, image):
    """Detect facial landmarks on *image* and return them normalized, or
    None when no face is found.

    *image* is a float frame in [0, 1]; the detector expects 0-255 pixel
    values, hence the rescale.  Only the first detected face is used.
    """
    detections = fa.get_landmarks(255 * image)
    if not detections:
        return None
    return normalize_alignment_kp(detections[0])

def is_new_frame_better(fa, source, driving, device):
global start_frame
global start_frame_kp
def is_new_frame_better(source, driving, precitor):
global avatar_kp
global display_string

if avatar_kp is None:
display_string = "No face detected in avatar."
return False

if start_frame is None:
if predictor.get_start_frame() is None:
display_string = "No frame to compare to."
return True

driving_smaller = resize(driving, (128, 128))[..., :3]
new_kp = get_frame_kp(fa, driving)
new_kp = predictor.get_frame_kp(driving)

if new_kp is not None:
new_norm = (np.abs(avatar_kp - new_kp) ** 2).sum()
old_norm = (np.abs(avatar_kp - start_frame_kp) ** 2).sum()
old_norm = (np.abs(avatar_kp - predictor.get_start_frame_kp()) ** 2).sum()

out_string = "{0} : {1}".format(int(new_norm * 100), int(old_norm * 100))
display_string = out_string
Expand Down Expand Up @@ -119,33 +68,6 @@ def pad_img(img, orig):
return out


def predict(driving_frame, source_image, relative, adapt_movement_scale, fa, device='cuda'):
    """Animate *source_image* with the motion of *driving_frame*.

    The first call records the driving frame as the reference for relative
    motion (module globals start_frame / start_frame_kp / kp_driving_initial);
    reset kp_driving_initial to None to re-key.  Returns a uint8 HxWx3 image.
    """
    global start_frame
    global start_frame_kp
    global kp_driving_initial

    def as_batch(frame):
        # HxWxC float frame -> 1xCxHxW tensor on the target device.
        arr = frame[np.newaxis].astype(np.float32)
        return torch.tensor(arr).permute(0, 3, 1, 2).to(device)

    with torch.no_grad():
        source = as_batch(source_image)
        driving = as_batch(driving_frame)
        kp_source = kp_detector(source)

        if kp_driving_initial is None:
            # First frame seen: remember it as the motion reference.
            kp_driving_initial = kp_detector(driving)
            start_frame = driving_frame.copy()
            start_frame_kp = get_frame_kp(fa, driving_frame)

        kp_norm = normalize_kp(kp_source=kp_source,
                               kp_driving=kp_detector(driving),
                               kp_driving_initial=kp_driving_initial,
                               use_relative_movement=relative,
                               use_relative_jacobian=relative,
                               adapt_movement_scale=adapt_movement_scale)
        generated = generator(source, kp_source=kp_source, kp_driving=kp_norm)

        # 1xCxHxW prediction -> HxWxC uint8 image.
        frame = np.transpose(generated['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]
        result = (np.clip(frame, 0, 1) * 255).astype(np.uint8)

    return result


def load_stylegan_avatar():
url = "https://thispersondoesnotexist.com/image"
r = requests.get(url, headers={'User-Agent': "My User Agent 1.0"}).content
Expand All @@ -158,48 +80,40 @@ def load_stylegan_avatar():

return image

def change_avatar(fa, new_avatar):
def change_avatar(predictor, new_avatar):
global avatar, avatar_kp
avatar_kp = get_frame_kp(fa, new_avatar)
avatar_kp = predictor.get_frame_kp(new_avatar)
avatar = new_avatar
predictor.set_source_image(avatar)

def log(*message, **print_kwargs):
    """Write a message to stderr, keeping stdout clean for piped output."""
    print(*message, file=sys.stderr, **print_kwargs)


if __name__ == "__main__":

global display_string
display_string = ""
global kp_driving_initial
kp_driving_initial = None

parser = ArgumentParser()
parser.add_argument("--config", required=True, help="path to config")
parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore")

parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
parser.add_argument("--no-pad", dest="no_pad", action="store_true", help="don't pad output image")

parser.add_argument("--cam", type=int, default=0, help="Webcam device ID")
parser.add_argument("--virt-cam", type=int, default=0, help="Virtualcam device ID")
parser.add_argument("--no-stream", action="store_true", help="On Linux, force no streaming")

parser.add_argument("--verbose", action="store_true", help="Print additional information")

parser.add_argument("--avatars", default="./avatars", help="path to avatars directory")

parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False)
parser.set_defaults(no_pad=False)

opt = parser.parse_args()

if opt.no_stream:
log('Force no streaming')
_streaming = False

device = 'cuda' if torch.cuda.is_available() else 'cpu'
log('Loading Predictor')
if opt.is_worker:
predictor_remote.run_worker(opt.worker_port)
sys.exit(0)
elif opt.worker_host:
predictor = predictor_remote.PredictorRemote(
worker_host=opt.worker_host, worker_port=opt.worker_port,
config_path=opt.config, checkpoint_path=opt.checkpoint,
relative=opt.relative, adapt_movement_scale=opt.adapt_scale
)
else:
predictor = predictor_local.PredictorLocal(
config_path=opt.config, checkpoint_path=opt.checkpoint,
relative=opt.relative, adapt_movement_scale=opt.adapt_scale
)

avatars=[]
images_list = sorted(glob.glob(f'{opt.avatars}/*'))
Expand All @@ -212,12 +126,6 @@ def log(*args, **kwargs):
img = resize(img, (256, 256))[..., :3]
avatars.append(img)

log('load checkpoints..')

generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, device=device)

fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True, device=device)

# cap = cv2.VideoCapture(opt.cam)
cap = VideoCaptureAsync(opt.cam)
if not cap.isOpened():
Expand All @@ -235,7 +143,7 @@ def log(*args, **kwargs):

cur_ava = 0
avatar = None
change_avatar(fa, avatars[cur_ava])
change_avatar(predictor, avatars[cur_ava])
passthrough = False

cv2.namedWindow('cam', cv2.WINDOW_GUI_NORMAL)
Expand Down Expand Up @@ -270,10 +178,10 @@ def log(*args, **kwargs):
frame = resize(frame, (256, 256))[..., :3]

if find_keyframe:
if is_new_frame_better(fa, avatar, frame, device):
if is_new_frame_better(avatar, frame, predictor):
log("Taking new frame!")
green_overlay = True
kp_driving_initial = None
predictor.reset_frames()

if opt.verbose:
preproc_time = (time.time() - t_start) * 1000
Expand All @@ -283,7 +191,7 @@ def log(*args, **kwargs):
out = frame_orig[..., ::-1]
else:
pred_start = time.time()
pred = predict(frame, avatar, opt.relative, opt.adapt_scale, fa, device=device)
pred = predictor.predict(frame)
out = pred
pred_time = (time.time() - pred_start) * 1000
if opt.verbose:
Expand All @@ -306,21 +214,21 @@ def log(*args, **kwargs):
if cur_ava >= len(avatars):
cur_ava = 0
passthrough = False
change_avatar(fa, avatars[cur_ava])
change_avatar(predictor, avatars[cur_ava])
elif key == ord('a'):
cur_ava -= 1
if cur_ava < 0:
cur_ava = len(avatars) - 1
passthrough = False
change_avatar(fa, avatars[cur_ava])
change_avatar(predictor, avatars[cur_ava])
elif key == ord('w'):
frame_proportion -= 0.05
frame_proportion = max(frame_proportion, 0.1)
elif key == ord('s'):
frame_proportion += 0.05
frame_proportion = min(frame_proportion, 1.0)
elif key == ord('x'):
kp_driving_initial = None
predictor.reset_frames()
elif key == ord('z'):
overlay_alpha = max(overlay_alpha - 0.1, 0.0)
elif key == ord('c'):
Expand All @@ -336,15 +244,15 @@ def log(*args, **kwargs):
log('Loading StyleGAN avatar...')
avatar = load_stylegan_avatar()
passthrough = False
change_avatar(fa, avatar)
change_avatar(predictor, avatar)
except:
log('Failed to load StyleGAN avatar')
elif key == ord('i'):
show_fps = not show_fps
elif 48 < key < 58:
cur_ava = min(key - 49, len(avatars) - 1)
passthrough = False
change_avatar(fa, avatars[cur_ava])
change_avatar(predictor, avatars[cur_ava])
elif key == 48:
passthrough = not passthrough
elif key != -1:
Expand Down
Loading