Skip to content

Commit

Permalink
Revert "Revert "Port DW Pose preprocessor (#1856)" (#1860)"
Browse files Browse the repository at this point in the history
This reverts commit 17e100e.
  • Loading branch information
huchenlei authored Aug 4, 2023
1 parent 17e100e commit ba86578
Show file tree
Hide file tree
Showing 14 changed files with 787 additions and 54 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,5 @@ annotator/downloads/

# test results and expectations
web_tests/results/
web_tests/expectations/
web_tests/expectations/
*_diff.png
69 changes: 52 additions & 17 deletions annotator/openpose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,18 @@
from .body import Body, BodyResult, Keypoint
from .hand import Hand
from .face import Face
from .wholebody import Wholebody # DW Pose
from .types import PoseResult, HandResult, FaceResult
from modules import devices
from annotator.annotator_path import models_path

from typing import NamedTuple, Tuple, List, Callable, Union, Optional
from typing import Tuple, List, Callable, Union, Optional

body_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/body_pose_model.pth"
hand_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/hand_pose_model.pth"
face_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/facenet.pth"
remote_dw_model_path = "https://huggingface.co/camenduru/DWPose/resolve/main/dw-ll_ucoco_384.pth"

HandResult = List[Keypoint]
FaceResult = List[Keypoint]

class PoseResult(NamedTuple):
body: BodyResult
left_hand: Union[HandResult, None]
right_hand: Union[HandResult, None]
face: Union[FaceResult, None]

def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
"""
Expand Down Expand Up @@ -179,6 +174,8 @@ def __init__(self):
self.hand_estimation = None
self.face_estimation = None

self.dw_pose_estimation = None

def load_model(self):
"""
Load the Openpose body, hand, and face models.
Expand All @@ -198,10 +195,17 @@ def load_model(self):
if not os.path.exists(face_modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(face_model_path, model_dir=self.model_dir)

self.body_estimation = Body(body_modelpath)
self.hand_estimation = Hand(hand_modelpath)
self.face_estimation = Face(face_modelpath)

def load_dw_model(self):
dw_modelpath = os.path.join(self.model_dir, "dw-ll_ucoco_384.pth")
if not os.path.exists(dw_modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(remote_dw_model_path, model_dir=self.model_dir)
self.dw_pose_estimation = Wholebody(dw_modelpath, device=self.device)

def unload_model(self):
"""
Expand All @@ -211,6 +215,11 @@ def unload_model(self):
self.body_estimation.model.to("cpu")
self.hand_estimation.model.to("cpu")
self.face_estimation.model.to("cpu")

def unload_dw_model(self):
if self.dw_pose_estimation is not None:
self.dw_pose_estimation.detector.to("cpu")
self.dw_pose_estimation.pose_estimator.to("cpu")

def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
left_hand = None
Expand Down Expand Up @@ -269,7 +278,7 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P

self.body_estimation.model.to(self.device)
self.hand_estimation.model.to(self.device)
self.face_estimation.model.to(self.device)
self.face_estimation.model.to(self.device)

self.body_estimation.cn_device = self.device
self.hand_estimation.cn_device = self.device
Expand Down Expand Up @@ -302,10 +311,31 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P
), left_hand, right_hand, face))

return results


def detect_poses_dw(self, oriImg) -> List[PoseResult]:
"""
Detect poses in the given image using DW Pose:
https://github.com/IDEA-Research/DWPose
Args:
oriImg (numpy.ndarray): The input image for pose detection.
Returns:
List[PoseResult]: A list of PoseResult objects containing the detected poses.
"""
if self.dw_pose_estimation is None:
self.load_dw_model()

self.dw_pose_estimation.detector.to(self.device)
self.dw_pose_estimation.pose_estimator.to(self.device)

with torch.no_grad():
keypoints_info = self.dw_pose_estimation(oriImg.copy())
return Wholebody.format_result(keypoints_info)

def __call__(
self, oriImg, include_body=True, include_hand=False, include_face=False,
json_pose_callback: Callable[[str], None] = None,
self, oriImg, include_body=True, include_hand=False, include_face=False,
use_dw_pose=False, json_pose_callback: Callable[[str], None] = None,
):
"""
Detect and draw poses in the given image.
Expand All @@ -315,14 +345,19 @@ def __call__(
include_body (bool, optional): Whether to include body keypoints. Defaults to True.
include_hand (bool, optional): Whether to include hand keypoints. Defaults to False.
include_face (bool, optional): Whether to include face keypoints. Defaults to False.
use_dw_pose (bool, optional): Whether to use DW pose detection algorithm. Defaults to False.
json_pose_callback (Callable, optional): A callback that accepts the pose JSON string.
Returns:
numpy.ndarray: The image with detected and drawn poses.
"""
H, W, _ = oriImg.shape
poses = self.detect_poses(oriImg, include_hand, include_face)

if use_dw_pose:
poses = self.detect_poses_dw(oriImg)
else:
poses = self.detect_poses(oriImg, include_hand, include_face)

if json_pose_callback:
json_pose_callback(encode_poses_as_json(poses, H, W))
return draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)

return draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
19 changes: 1 addition & 18 deletions annotator/openpose/body.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,7 @@

from . import util
from .model import bodypose_model

class Keypoint(NamedTuple):
x: float
y: float
score: float = 1.0
id: int = -1


class BodyResult(NamedTuple):
# Note: Using `Union` instead of `|` operator as the ladder is a Python
# 3.10 feature.
# Annotator code should be Python 3.8 Compatible, as controlnet repo uses
# Python 3.8 environment.
# https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
keypoints: List[Union[Keypoint, None]]
total_score: float = 0.0
total_parts: int = 0

from .types import Keypoint, BodyResult

class Body(object):
def __init__(self, model_path):
Expand Down
Loading

0 comments on commit ba86578

Please sign in to comment.