Skip to content

Commit

Permalink
Support PuLID (#2838)
Browse files Browse the repository at this point in the history
* Add preprocessors

* Fix resolution param

* Fix various issues

* Add PuLID attn

* remove unused import

* Resize img before passing to facexlib

* safe unload
  • Loading branch information
huchenlei authored May 4, 2024
1 parent 36a310f commit 784b6d0
Show file tree
Hide file tree
Showing 18 changed files with 571 additions and 80 deletions.
7 changes: 6 additions & 1 deletion internal_controlnet/external_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from modules.safe import unsafe_torch_load
from scripts import global_state
from scripts.logging import logger
from scripts.enums import HiResFixOption
from scripts.enums import HiResFixOption, PuLIDMode
from scripts.supported_preprocessor import Preprocessor, PreprocessorParameter

from modules.api import api
Expand Down Expand Up @@ -207,6 +207,10 @@ class ControlNetUnit:
# The effective region mask that unit's effect should be restricted to.
effective_region_mask: Optional[np.ndarray] = None

# The weight mode for PuLID.
# https://github.com/ToTheBeginning/PuLID
pulid_mode: PuLIDMode = PuLIDMode.FIDELITY

# The tensor input for ipadapter. When this field is set in the API,
# the base64string will be interpret by torch.load to reconstruct ipadapter
# preprocessor output.
Expand Down Expand Up @@ -243,6 +247,7 @@ def infotext_excluded_fields() -> List[str]:
# provide much information when restoring the unit.
"inpaint_crop_input_image",
"effective_region_mask",
"pulid_mode",
]

@property
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ addict
yapf
albumentations==1.4.3
matplotlib
facexlib
2 changes: 1 addition & 1 deletion scripts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def accept(self, json_dict: dict) -> None:
low_vram=low_vram,
)
if preprocessor.returns_image:
images.append(encode_to_base64(result.display_image))
images.append(encode_to_base64(result.display_images[0]))
else:
tensors.append(encode_tensor_to_base64(result.value))

Expand Down
22 changes: 17 additions & 5 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import scripts.preprocessor as preprocessor_init # noqa
from annotator.util import HWC3
from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils
from internal_controlnet.external_code import ControlMode
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.controlnet_lllite import clear_all_lllite
from scripts.ipadapter.plugable_ipadapter import ImageEmbed, clear_all_ip_adapter
from scripts.ipadapter.pulid_attn import PULID_SETTING_FIDELITY, PULID_SETTING_STYLE
from scripts.utils import load_state_dict, get_unique_axis0, align_dim_latent
from scripts.hook import ControlParams, UnetHook, HackedImageRNG
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption
from scripts.enums import ControlModelType, StableDiffusionVersion, HiResFixOption, PuLIDMode
from scripts.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from scripts.controlnet_ui.photopea import Photopea
from scripts.logging import logger
Expand Down Expand Up @@ -279,6 +281,7 @@ def preprocess_input_image(input_image: np.ndarray):
)
detected_map = result.value
is_image = preprocessor.returns_image
# TODO: Refactor img control detection logic.
if high_res_fix:
if is_image:
hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
Expand All @@ -293,7 +296,8 @@ def preprocess_input_image(input_image: np.ndarray):
store_detected_map(detected_map, unit.module)
else:
control = detected_map
store_detected_map(input_image, unit.module)
for image in result.display_images:
store_detected_map(image, unit.module)

if control_model_type == ControlModelType.T2I_StyleAdapter:
control = control['last_hidden_state']
Expand Down Expand Up @@ -1092,8 +1096,8 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe
global_average_pooling=global_average_pooling,
hr_hint_cond=hr_control,
hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH,
soft_injection=control_mode != external_code.ControlMode.BALANCED,
cfg_injection=control_mode == external_code.ControlMode.CONTROL,
soft_injection=control_mode != ControlMode.BALANCED,
cfg_injection=control_mode == ControlMode.CONTROL,
effective_region_mask=(
get_pytorch_control(unit.effective_region_mask)[:, 0:1, :, :]
if unit.effective_region_mask is not None
Expand Down Expand Up @@ -1190,7 +1194,7 @@ def recolor_intensity_post_processing(x, i):

is_low_vram = any(unit.low_vram for unit in self.enabled_units)

for i, param in enumerate(forward_params):
for i, (param, unit) in enumerate(zip(forward_params, self.enabled_units)):
if param.control_model_type == ControlModelType.IPAdapter:
if param.advanced_weighting is not None:
logger.info(f"IP-Adapter using advanced weighting {param.advanced_weighting}")
Expand All @@ -1205,6 +1209,13 @@ def recolor_intensity_post_processing(x, i):
weight = param.weight

h, w, hr_y, hr_x = Script.get_target_dimensions(p)
pulid_mode = PuLIDMode(unit.pulid_mode)
if pulid_mode == PuLIDMode.STYLE:
pulid_attn_setting = PULID_SETTING_STYLE
else:
assert pulid_mode == PuLIDMode.FIDELITY
pulid_attn_setting = PULID_SETTING_FIDELITY

param.control_model.hook(
model=unet,
preprocessor_outputs=param.hint_cond,
Expand All @@ -1215,6 +1226,7 @@ def recolor_intensity_post_processing(x, i):
latent_width=w // 8,
latent_height=h // 8,
effective_region_mask=param.effective_region_mask,
pulid_attn_setting=pulid_attn_setting,
)
if param.control_model_type == ControlModelType.Controlllite:
param.control_model.hook(
Expand Down
24 changes: 22 additions & 2 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from scripts.controlnet_ui.preset import ControlNetPresetUI
from scripts.controlnet_ui.photopea import Photopea
from scripts.controlnet_ui.advanced_weight_control import AdvancedWeightControl
from scripts.enums import InputMode
from scripts.enums import InputMode, PuLIDMode
from modules import shared
from modules.ui_components import FormRow, FormHTML, ToolButton

Expand Down Expand Up @@ -287,6 +287,7 @@ def __init__(
self.batch_image_dir_state = None
self.output_dir_state = None
self.advanced_weighting = gr.State(None)
self.pulid_mode = None

# API-only fields
self.ipadapter_input = gr.State(None)
Expand Down Expand Up @@ -626,6 +627,15 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
visible=False,
)

self.pulid_mode = gr.Radio(
choices=[e.value for e in PuLIDMode],
value=self.default_unit.pulid_mode.value,
label="PuLID Mode",
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pulid_mode_radio",
elem_classes="controlnet_pulid_mode_radio",
visible=False,
)

self.loopback = gr.Checkbox(
label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
value=self.default_unit.loopback,
Expand Down Expand Up @@ -673,6 +683,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
self.save_detected_map,
self.advanced_weighting,
self.effective_region_mask,
self.pulid_mode,
)

unit = gr.State(self.default_unit)
Expand Down Expand Up @@ -947,7 +958,7 @@ def is_openpose(module: str):

return (
# Update to `generated_image`
gr.update(value=result.display_image, visible=True, interactive=False),
gr.update(value=result.display_images[0], visible=True, interactive=False),
# preprocessor_preview
gr.update(value=True),
# openpose editor
Expand Down Expand Up @@ -1118,6 +1129,14 @@ def register_shift_upload_mask(self):
show_progress=False,
)

def register_shift_pulid_mode(self):
self.model.change(
fn=lambda model: gr.update(visible="pulid" in model.lower()),
inputs=[self.model],
outputs=[self.pulid_mode],
show_progress=False,
)

def register_sync_batch_dir(self):
def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir):
if batch_dir:
Expand Down Expand Up @@ -1220,6 +1239,7 @@ def register_core_callbacks(self):
self.register_build_sliders()
self.register_shift_preview()
self.register_shift_upload_mask()
self.register_shift_pulid_mode()
self.register_create_canvas()
self.register_clear_preview()
self.register_multi_images_upload()
Expand Down
5 changes: 5 additions & 0 deletions scripts/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,8 @@ class InputMode(Enum):
# Input is a directory. 1 generation. Each generation takes N input image
# from the directory.
MERGE = "merge"


class PuLIDMode(Enum):
FIDELITY = "Fidelity"
STYLE = "Extremely style"
62 changes: 62 additions & 0 deletions scripts/ipadapter/image_proj_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,65 @@ def forward(self, x):

latents = self.proj_out(latents)
return self.norm_out(latents)


class PuLIDEncoder(nn.Module):
def __init__(self, width=1280, context_dim=2048, num_token=5):
super().__init__()
self.num_token = num_token
self.context_dim = context_dim
h1 = min((context_dim * num_token) // 4, 1024)
h2 = min((context_dim * num_token) // 2, 1024)
self.body = nn.Sequential(
nn.Linear(width, h1),
nn.LayerNorm(h1),
nn.LeakyReLU(),
nn.Linear(h1, h2),
nn.LayerNorm(h2),
nn.LeakyReLU(),
nn.Linear(h2, context_dim * num_token),
)

for i in range(5):
setattr(
self,
f"mapping_{i}",
nn.Sequential(
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, context_dim),
),
)

setattr(
self,
f"mapping_patch_{i}",
nn.Sequential(
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, 1024),
nn.LayerNorm(1024),
nn.LeakyReLU(),
nn.Linear(1024, context_dim),
),
)

def forward(self, x, y):
# x shape [N, C]
x = self.body(x)
x = x.reshape(-1, self.num_token, self.context_dim)

hidden_states = ()
for i, emb in enumerate(y):
hidden_state = getattr(self, f"mapping_{i}")(emb[:, :1]) + getattr(
self, f"mapping_patch_{i}"
)(emb[:, 1:]).mean(dim=1, keepdim=True)
hidden_states += (hidden_state,)
hidden_states = torch.cat(hidden_states, dim=1)

return torch.cat([x, hidden_states], dim=1)
45 changes: 41 additions & 4 deletions scripts/ipadapter/ipadapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
MLPProjModel,
MLPProjModelFaceId,
ProjModelFaceIdPlus,
PuLIDEncoder,
)


Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
is_faceid: bool,
is_portrait: bool,
is_instantid: bool,
is_pulid: bool,
is_v2: bool,
):
super().__init__()
Expand All @@ -85,9 +87,12 @@ def __init__(
self.is_v2 = is_v2
self.is_faceid = is_faceid
self.is_instantid = is_instantid
self.is_pulid = is_pulid
self.clip_extra_context_tokens = 16 if (self.is_plus or is_portrait) else 4

if is_instantid:
if self.is_pulid:
self.image_proj_model = PuLIDEncoder()
elif self.is_instantid:
self.image_proj_model = self.init_proj_instantid()
elif is_faceid:
self.image_proj_model = self.init_proj_faceid()
Expand Down Expand Up @@ -235,6 +240,34 @@ def _get_image_embeds_instantid(
self.image_proj_model(torch.zeros_like(prompt_image_emb)),
)

def _get_image_embeds_pulid(self, pulid_proj_input) -> ImageEmbed:
"""Get image embeds for pulid."""
id_cond = torch.cat(
[
pulid_proj_input.id_ante_embedding.to(
device=self.device, dtype=torch.float32
),
pulid_proj_input.id_cond_vit.to(
device=self.device, dtype=torch.float32
),
],
dim=-1,
)
id_vit_hidden = [
t.to(device=self.device, dtype=torch.float32)
for t in pulid_proj_input.id_vit_hidden
]
return ImageEmbed(
self.image_proj_model(
id_cond,
id_vit_hidden,
),
self.image_proj_model(
torch.zeros_like(id_cond),
[torch.zeros_like(t) for t in id_vit_hidden],
),
)

@staticmethod
def load(state_dict: dict, model_name: str) -> IPAdapterModel:
"""
Expand All @@ -245,6 +278,7 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel:
is_v2 = "v2" in model_name
is_faceid = "faceid" in model_name
is_instantid = "instant_id" in model_name
is_pulid = "pulid" in model_name.lower()
is_portrait = "portrait" in model_name
is_full = "proj.3.weight" in state_dict["image_proj"]
is_plus = (
Expand All @@ -256,8 +290,8 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel:
sdxl = cross_attention_dim == 2048
sdxl_plus = sdxl and is_plus

if is_instantid:
# InstantID does not use clip embedding.
if is_instantid or is_pulid:
# InstantID/PuLID does not use clip embedding.
clip_embeddings_dim = None
elif is_faceid:
if is_plus:
Expand Down Expand Up @@ -291,10 +325,13 @@ def load(state_dict: dict, model_name: str) -> IPAdapterModel:
is_portrait=is_portrait,
is_instantid=is_instantid,
is_v2=is_v2,
is_pulid=is_pulid,
)

def get_image_emb(self, preprocessor_output) -> ImageEmbed:
if self.is_instantid:
if self.is_pulid:
return self._get_image_embeds_pulid(preprocessor_output)
elif self.is_instantid:
return self._get_image_embeds_instantid(preprocessor_output)
elif self.is_faceid and self.is_plus:
# Note: FaceID plus uses both face_embed and clip_embed.
Expand Down
Loading

0 comments on commit 784b6d0

Please sign in to comment.