-
Notifications
You must be signed in to change notification settings - Fork 13
/
nersemble.py
84 lines (69 loc) · 3.22 KB
/
nersemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
# property and proprietary rights in and to this software and related documentation.
# Any commercial use, reproduction, disclosure or distribution of this software and
# related documentation without an express license agreement from Toyota Motor Europe NV/SA
# is strictly prohibited.
#
from typing import Optional, Literal
from dataclasses import dataclass
import tyro
from vhap.config.base import (
StageRgbSequentialTrackingConfig, StageRgbGlobalTrackingConfig, PipelineConfig,
DataConfig, LossWeightConfig, BaseTrackingConfig,
)
from vhap.util.log import get_logger
logger = get_logger(__name__)
@dataclass()
class NersembleDataConfig(DataConfig):
_target: str = "vhap.data.nersemble_dataset.NeRSembleDataset"
calibrated: bool = True
background_color: Optional[Literal['white', 'black']] = None
landmark_source: Optional[Literal["face-alignment", 'star']] = "star"
subject: str = ""
"""Subject ID. Such as 018, 218, 251, 253"""
use_color_correction: bool = True
"""Whether to use color correction to harmonize the color of the input images."""
@dataclass()
class NersembleLossWeightConfig(LossWeightConfig):
landmark: Optional[float] = 3. # should not be lower to avoid collapse
always_enable_jawline_landmarks: bool = False # allow disable_jawline_landmarks in StageConfig to work
reg_expr: float = 1e-2 # for best expressivness
reg_tex_tv: Optional[float] = 1e5 # 10x of the base value
@dataclass()
class NersembleStageRgbSequentialTrackingConfig(StageRgbSequentialTrackingConfig):
optimizable_params: tuple[str, ...] = ("pose", "joints", "expr", "dynamic_offset")
align_texture_except: tuple[str, ...] = ("boundary",)
align_boundary_except: tuple[str, ...] = ("boundary",)
"""Due to the limited flexibility in the lower neck region of FLAME, we relax the
alignment constraints for better alignment in the face region.
"""
@dataclass()
class NersembleStageRgbGlobalTrackingConfig(StageRgbGlobalTrackingConfig):
align_texture_except: tuple[str, ...] = ("boundary",)
align_boundary_except: tuple[str, ...] = ("boundary",)
"""Due to the limited flexibility in the lower neck region of FLAME, we relax the
alignment constraints for better alignment in the face region.
"""
@dataclass()
class NersemblePipelineConfig(PipelineConfig):
rgb_sequential_tracking: NersembleStageRgbSequentialTrackingConfig
rgb_global_tracking: NersembleStageRgbGlobalTrackingConfig
@dataclass()
class NersembleTrackingConfig(BaseTrackingConfig):
data: NersembleDataConfig
w: NersembleLossWeightConfig
pipeline: NersemblePipelineConfig
def get_occluded(self):
occluded_table = {
'018': ('neck_lower',),
'218': ('neck_lower',),
'251': ('neck_lower', 'boundary'),
'253': ('neck_lower',),
}
if self.data.subject in occluded_table:
logger.info(f"Automatically setting cfg.model.occluded to {occluded_table[self.data.subject]}")
self.model.occluded = occluded_table[self.data.subject]
if __name__ == "__main__":
config = tyro.cli(NersembleTrackingConfig)
print(tyro.to_yaml(config))