-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Unify SMPL-like models to mesh models (#830)
* Add unified SMPL-like model interface and builder
- Loading branch information
1 parent
ce68b4c
commit 0ecf06e
Showing
9 changed files
with
289 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,15 @@ | ||
from .backbones import * # noqa | ||
from .builder import (BACKBONES, HEADS, LOSSES, NECKS, POSENETS, | ||
build_backbone, build_head, build_loss, build_neck, | ||
build_posenet) | ||
from .builder import (BACKBONES, HEADS, LOSSES, MESH_MODELS, NECKS, POSENETS, | ||
build_backbone, build_head, build_loss, build_mesh_model, | ||
build_neck, build_posenet) | ||
from .detectors import * # noqa | ||
from .heads import * # noqa | ||
from .losses import * # noqa | ||
from .necks import * # noqa | ||
from .utils import * # noqa | ||
|
||
__all__ = [ | ||
'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'POSENETS', 'build_backbone', | ||
'build_head', 'build_loss', 'build_posenet', 'build_neck' | ||
'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'POSENETS', 'MESH_MODELS', | ||
'build_backbone', 'build_head', 'build_loss', 'build_posenet', | ||
'build_neck', 'build_mesh_model' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .smpl import SMPL | ||
|
||
__all__ = ['SMPL'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
from ..builder import MESH_MODELS | ||
|
||
try: | ||
from smplx import SMPL as SMPL_ | ||
has_smpl = True | ||
except (ImportError, ModuleNotFoundError): | ||
has_smpl = False | ||
|
||
|
||
@MESH_MODELS.register_module() | ||
class SMPL(nn.Module): | ||
"""SMPL 3d human mesh model of paper ref: Matthew Loper. ``SMPL: A skinned | ||
multi-person linear model''. This module is based on the smplx project | ||
(https://github.com/vchoutas/smplx). | ||
Args: | ||
smpl_path (str): The path to the folder where the model weights are | ||
stored. | ||
joints_regressor (str): The path to the file where the joints | ||
regressor weight are stored. | ||
""" | ||
|
||
def __init__(self, smpl_path, joints_regressor): | ||
super().__init__() | ||
|
||
assert has_smpl, 'Please install smplx to use SMPL.' | ||
|
||
self.smpl_neutral = SMPL_( | ||
model_path=smpl_path, | ||
create_global_orient=False, | ||
create_body_pose=False, | ||
create_transl=False, | ||
gender='neutral') | ||
|
||
self.smpl_male = SMPL_( | ||
model_path=smpl_path, | ||
create_betas=False, | ||
create_global_orient=False, | ||
create_body_pose=False, | ||
create_transl=False, | ||
gender='male') | ||
|
||
self.smpl_female = SMPL_( | ||
model_path=smpl_path, | ||
create_betas=False, | ||
create_global_orient=False, | ||
create_body_pose=False, | ||
create_transl=False, | ||
gender='female') | ||
|
||
joints_regressor = torch.tensor( | ||
np.load(joints_regressor), dtype=torch.float)[None, ...] | ||
self.register_buffer('joints_regressor', joints_regressor) | ||
|
||
self.num_verts = self.smpl_neutral.get_num_verts() | ||
self.num_joints = self.joints_regressor.shape[1] | ||
|
||
def smpl_forward(self, model, **kwargs): | ||
"""Apply a specific SMPL model with given model parameters. | ||
Note: | ||
B: batch size | ||
V: number of vertices | ||
K: number of joints | ||
Returns: | ||
outputs (dict): Dict with mesh vertices and joints. | ||
- vertices: Tensor([B, V, 3]), mesh vertices | ||
- joints: Tensor([B, K, 3]), 3d joints regressed | ||
from mesh vertices. | ||
""" | ||
|
||
betas = kwargs['betas'] | ||
batch_size = betas.shape[0] | ||
device = betas.device | ||
output = {} | ||
if batch_size == 0: | ||
output['vertices'] = betas.new_zeros([0, self.num_verts, 3]) | ||
output['joints'] = betas.new_zeros([0, self.num_joints, 3]) | ||
else: | ||
smpl_out = model(**kwargs) | ||
output['vertices'] = smpl_out.vertices | ||
output['joints'] = torch.matmul( | ||
self.joints_regressor.to(device), output['vertices']) | ||
return output | ||
|
||
def get_faces(self): | ||
"""Return mesh faces. | ||
Note: | ||
F: number of faces | ||
Returns: | ||
faces: np.ndarray([F, 3]), mesh faces | ||
""" | ||
return self.smpl_neutral.faces | ||
|
||
def forward(self, | ||
betas, | ||
body_pose, | ||
global_orient, | ||
transl=None, | ||
gender=None): | ||
"""Forward function. | ||
Note: | ||
B: batch size | ||
J: number of controllable joints of model, for smpl model J=23 | ||
K: number of joints | ||
Args: | ||
betas: Tensor([B, 10]), human body shape parameters of SMPL model. | ||
body_pose: Tensor([B, J*3] or [B, J, 3, 3]), human body pose | ||
parameters of SMPL model. It should be axis-angle vector | ||
([B, J*3]) or rotation matrix ([B, J, 3, 3)]. | ||
global_orient: Tensor([B, 3] or [B, 1, 3, 3]), global orientation | ||
of human body. It should be axis-angle vector ([B, 3]) or | ||
rotation matrix ([B, 1, 3, 3)]. | ||
transl: Tensor([B, 3]), global translation of human body. | ||
gender: Tensor([B]), gender parameters of human body. -1 for | ||
neutral, 0 for male , 1 for female. | ||
Returns: | ||
outputs (dict): Dict with mesh vertices and joints. | ||
- vertices: Tensor([B, V, 3]), mesh vertices | ||
- joints: Tensor([B, K, 3]), 3d joints regressed from | ||
mesh vertices. | ||
""" | ||
|
||
batch_size = betas.shape[0] | ||
pose2rot = True if body_pose.dim() == 2 else False | ||
if batch_size > 0 and gender is not None: | ||
output = { | ||
'vertices': betas.new_zeros([batch_size, self.num_verts, 3]), | ||
'joints': betas.new_zeros([batch_size, self.num_joints, 3]) | ||
} | ||
|
||
mask = gender < 0 | ||
_out = self.smpl_forward( | ||
self.smpl_neutral, | ||
betas=betas[mask], | ||
body_pose=body_pose[mask], | ||
global_orient=global_orient[mask], | ||
transl=transl[mask] if transl is not None else None, | ||
pose2rot=pose2rot) | ||
output['vertices'][mask] = _out['vertices'] | ||
output['joints'][mask] = _out['joints'] | ||
|
||
mask = gender == 0 | ||
_out = self.smpl_forward( | ||
self.smpl_male, | ||
betas=betas[mask], | ||
body_pose=body_pose[mask], | ||
global_orient=global_orient[mask], | ||
transl=transl[mask] if transl is not None else None, | ||
pose2rot=pose2rot) | ||
output['vertices'][mask] = _out['vertices'] | ||
output['joints'][mask] = _out['joints'] | ||
|
||
mask = gender == 1 | ||
_out = self.smpl_forward( | ||
self.smpl_male, | ||
betas=betas[mask], | ||
body_pose=body_pose[mask], | ||
global_orient=global_orient[mask], | ||
transl=transl[mask] if transl is not None else None, | ||
pose2rot=pose2rot) | ||
output['vertices'][mask] = _out['vertices'] | ||
output['joints'][mask] = _out['joints'] | ||
else: | ||
return self.smpl_forward( | ||
self.smpl_neutral, | ||
betas=betas, | ||
body_pose=body_pose, | ||
global_orient=global_orient, | ||
transl=transl, | ||
pose2rot=pose2rot) | ||
|
||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import numpy as np | ||
import torch | ||
from tests.test_model.test_mesh_forward import generate_smpl_weight_file | ||
|
||
from mmpose.models.utils import SMPL | ||
|
||
|
||
def test_smpl(): | ||
"""Test smpl model.""" | ||
|
||
# generate weight file for SMPL model. | ||
generate_smpl_weight_file('tests/data/smpl') | ||
|
||
# build smpl model | ||
smpl_cfg = dict( | ||
smpl_path='tests/data/smpl', | ||
joints_regressor='tests/data/smpl/test_joint_regressor.npy') | ||
smpl = SMPL(**smpl_cfg) | ||
|
||
# test get face function | ||
faces = smpl.get_faces() | ||
assert isinstance(faces, np.ndarray) | ||
|
||
betas = torch.zeros(3, 10) | ||
body_pose = torch.zeros(3, 23 * 3) | ||
global_orient = torch.zeros(3, 3) | ||
transl = torch.zeros(3, 3) | ||
gender = torch.LongTensor([-1, 0, 1]) | ||
|
||
# test forward with body_pose and global_orient in axis-angle format | ||
smpl_out = smpl( | ||
betas=betas, body_pose=body_pose, global_orient=global_orient) | ||
assert isinstance(smpl_out, dict) | ||
assert smpl_out['vertices'].shape == torch.Size([3, 6890, 3]) | ||
assert smpl_out['joints'].shape == torch.Size([3, 24, 3]) | ||
|
||
# test forward with body_pose and global_orient in rotation matrix format | ||
body_pose = torch.eye(3).repeat([3, 23, 1, 1]) | ||
global_orient = torch.eye(3).repeat([3, 1, 1, 1]) | ||
_ = smpl(betas=betas, body_pose=body_pose, global_orient=global_orient) | ||
|
||
# test forward with translation | ||
_ = smpl( | ||
betas=betas, | ||
body_pose=body_pose, | ||
global_orient=global_orient, | ||
transl=transl) | ||
|
||
# test forward with gender | ||
_ = smpl( | ||
betas=betas, | ||
body_pose=body_pose, | ||
global_orient=global_orient, | ||
transl=transl, | ||
gender=gender) | ||
|
||
# test forward when all samples in the same gender | ||
gender = torch.LongTensor([0, 0, 0]) | ||
_ = smpl( | ||
betas=betas, | ||
body_pose=body_pose, | ||
global_orient=global_orient, | ||
transl=transl, | ||
gender=gender) | ||
|
||
# test forward when batch size = 0 | ||
_ = smpl( | ||
betas=torch.zeros(0, 10), | ||
body_pose=torch.zeros(0, 23 * 3), | ||
global_orient=torch.zeros(0, 3)) |
Oops, something went wrong.