Skip to content

Commit

Permalink
Drop external_code. prefix (#2850)
Browse files Browse the repository at this point in the history
* Drop external_code. prefix

* Remove unused imports
  • Loading branch information
huchenlei authored May 5, 2024
1 parent 2c59fec commit abafad2
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 48 deletions.
17 changes: 9 additions & 8 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Register all preprocessors.
import scripts.preprocessor as preprocessor_init # noqa
from annotator.util import HWC3
from internal_controlnet.external_code import ControlNetUnit
from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.controlnet_lllite import clear_all_lllite
Expand Down Expand Up @@ -228,7 +229,7 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor:

def get_control(
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
idx: int,
control_model_type: ControlModelType,
preprocessor: Preprocessor,
Expand Down Expand Up @@ -338,7 +339,7 @@ def __init__(self) -> None:
self.latest_network = None
self.input_image = None
self.latest_model_hash = ""
self.enabled_units: List[external_code.ControlNetUnit] = []
self.enabled_units: List[ControlNetUnit] = []
self.detected_map = []
self.post_processors = []
self.noise_modifier = None
Expand All @@ -356,7 +357,7 @@ def show(self, is_img2img):

@staticmethod
def get_default_ui_unit(is_ui=True):
cls = UiControlNetUnit if is_ui else external_code.ControlNetUnit
cls = UiControlNetUnit if is_ui else ControlNetUnit
return cls(
enabled=False,
module="none",
Expand Down Expand Up @@ -527,7 +528,7 @@ def get_element(obj, strict=False):
return attribute_value if attribute_value is not None else default

@staticmethod
def parse_remote_call(p, unit: external_code.ControlNetUnit, idx):
def parse_remote_call(p, unit: ControlNetUnit, idx):
selector = Script.get_remote_call

unit.enabled = selector(p, "control_net_enabled", unit.enabled, idx, strict=True)
Expand Down Expand Up @@ -688,7 +689,7 @@ def get_enabled_units(p):
@staticmethod
def choose_input_image(
p: processing.StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
idx: int
) -> Tuple[np.ndarray, ResizeMode]:
""" Choose input image from following sources with descending priority:
Expand All @@ -701,7 +702,7 @@ def choose_input_image(
- The input image in ndarray form.
- The resize mode.
"""
def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
def parse_unit_image(unit: ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
unit_has_multiple_images = (
isinstance(unit.image, list) and
len(unit.image) > 0 and
Expand Down Expand Up @@ -810,7 +811,7 @@ def decode_image(img) -> np.ndarray:
@staticmethod
def try_crop_image_with_a1111_mask(
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
input_image: np.ndarray,
resize_mode: ResizeMode,
) -> np.ndarray:
Expand Down Expand Up @@ -863,7 +864,7 @@ def try_crop_image_with_a1111_mask(
return input_image

@staticmethod
def check_sd_version_compatible(unit: external_code.ControlNetUnit) -> None:
def check_sd_version_compatible(unit: ControlNetUnit) -> None:
"""
Checks whether the given ControlNet unit has model compatible with the currently
active sd model. An exception is thrown if ControlNet unit is detected to be
Expand Down
9 changes: 5 additions & 4 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
external_code,
)
from annotator.util import HWC3
from internal_controlnet.external_code import ControlNetUnit
from scripts.logging import logger
from scripts.controlnet_ui.openpose_editor import OpenposeEditor
from scripts.controlnet_ui.preset import ControlNetPresetUI
Expand Down Expand Up @@ -127,7 +128,7 @@ def set_component(self, component: gr.components.Component):
)


class UiControlNetUnit(external_code.ControlNetUnit):
class UiControlNetUnit(ControlNetUnit):
"""The data class that stores all states of a ControlNetUnit."""

def __init__(
Expand Down Expand Up @@ -167,7 +168,7 @@ def __init__(
self.output_dir = output_dir
self.loopback = loopback

def unfold_merged(self) -> List[external_code.ControlNetUnit]:
def unfold_merged(self) -> List[ControlNetUnit]:
"""Unfolds a merged unit to multiple units. Keeps the unit merged for
preprocessors that can accept multiple input images.
"""
Expand Down Expand Up @@ -220,7 +221,7 @@ class ControlNetUiGroup(object):
def __init__(
self,
is_img2img: bool,
default_unit: external_code.ControlNetUnit,
default_unit: ControlNetUnit,
photopea: Optional[Photopea],
):
# Whether callbacks have been registered.
Expand Down Expand Up @@ -1260,7 +1261,7 @@ def register_core_callbacks(self):
self.type_filter,
*[
getattr(self, key)
for key in vars(external_code.ControlNetUnit()).keys()
for key in vars(ControlNetUnit()).keys()
],
)
self.advanced_weight_control.register_callbacks(
Expand Down
16 changes: 8 additions & 8 deletions scripts/controlnet_ui/preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from modules import scripts
from modules.ui_components import ToolButton
from internal_controlnet.external_code import ControlNetUnit
from scripts.infotext import parse_unit, serialize_unit
from scripts.logging import logger
from scripts import external_code
from scripts.supported_preprocessor import Preprocessor

save_symbol = "\U0001f4be" # 💾
Expand Down Expand Up @@ -113,15 +113,15 @@ def apply_preset(name: str, control_type: str, *ui_states):
gr.update(visible=False),
*(
(gr.skip(),)
* (len(vars(external_code.ControlNetUnit()).keys()) + 1)
* (len(vars(ControlNetUnit()).keys()) + 1)
),
)

assert name in ControlNetPresetUI.presets

infotext = ControlNetPresetUI.presets[name]
preset_unit = parse_unit(infotext)
current_unit = external_code.ControlNetUnit(*ui_states)
current_unit = ControlNetUnit(*ui_states)
preset_unit.image = None
current_unit.image = None

Expand All @@ -136,7 +136,7 @@ def apply_preset(name: str, control_type: str, *ui_states):
gr.update(visible=False),
*(
(gr.skip(),)
* (len(vars(external_code.ControlNetUnit()).keys()) + 1)
* (len(vars(ControlNetUnit()).keys()) + 1)
),
)

Expand Down Expand Up @@ -177,7 +177,7 @@ def save_preset(name: str, *ui_states):
return gr.update(visible=True), gr.update(), gr.update()

ControlNetPresetUI.save_preset(
name, external_code.ControlNetUnit(*ui_states)
name, ControlNetUnit(*ui_states)
)
return (
gr.update(), # name dialog
Expand Down Expand Up @@ -222,7 +222,7 @@ def save_new_preset(new_name: str, *ui_states):
return gr.update(visible=False), gr.update()

ControlNetPresetUI.save_preset(
new_name, external_code.ControlNetUnit(*ui_states)
new_name, ControlNetUnit(*ui_states)
)
return gr.update(visible=False), gr.update(
choices=ControlNetPresetUI.dropdown_choices(), value=new_name
Expand All @@ -248,7 +248,7 @@ def update_reset_button(preset_name: str, *ui_states):

infotext = ControlNetPresetUI.presets[preset_name]
preset_unit = parse_unit(infotext)
current_unit = external_code.ControlNetUnit(*ui_states)
current_unit = ControlNetUnit(*ui_states)
preset_unit.image = None
current_unit.image = None

Expand Down Expand Up @@ -279,7 +279,7 @@ def dropdown_choices() -> List[str]:
return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET]

@staticmethod
def save_preset(name: str, unit: external_code.ControlNetUnit):
def save_preset(name: str, unit: ControlNetUnit):
infotext = serialize_unit(unit)
with open(
os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w"
Expand Down
16 changes: 8 additions & 8 deletions scripts/infotext.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from modules.processing import StableDiffusionProcessing

from scripts import external_code
from internal_controlnet.external_code import ControlNetUnit
from scripts.logging import logger


Expand All @@ -28,12 +28,12 @@ def parse_value(value: str) -> Union[str, float, int, bool]:
return value # Plain string.


def serialize_unit(unit: external_code.ControlNetUnit) -> str:
excluded_fields = external_code.ControlNetUnit.infotext_excluded_fields()
def serialize_unit(unit: ControlNetUnit) -> str:
excluded_fields = ControlNetUnit.infotext_excluded_fields()

log_value = {
field_to_displaytext(field): getattr(unit, field)
for field in vars(external_code.ControlNetUnit()).keys()
for field in vars(ControlNetUnit()).keys()
if field not in excluded_fields and getattr(unit, field) != -1
# Note: exclude hidden slider values.
}
Expand All @@ -44,8 +44,8 @@ def serialize_unit(unit: external_code.ControlNetUnit) -> str:
return ", ".join(f"{field}: {value}" for field, value in log_value.items())


def parse_unit(text: str) -> external_code.ControlNetUnit:
return external_code.ControlNetUnit(
def parse_unit(text: str) -> ControlNetUnit:
return ControlNetUnit(
enabled=True,
**{
displaytext_to_field(key): parse_value(value)
Expand Down Expand Up @@ -74,7 +74,7 @@ def register_unit(self, unit_index: int, uigroup) -> None:
iocomponents.
"""
unit_prefix = Infotext.unit_prefix(unit_index)
for field in vars(external_code.ControlNetUnit()).keys():
for field in vars(ControlNetUnit()).keys():
# Exclude image for infotext.
if field == "image":
continue
Expand All @@ -88,7 +88,7 @@ def register_unit(self, unit_index: int, uigroup) -> None:

@staticmethod
def write_infotext(
units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing
units: List[ControlNetUnit], p: StableDiffusionProcessing
):
"""Write infotext to `p`."""
p.extra_generation_params.update(
Expand Down
15 changes: 8 additions & 7 deletions tests/cn_script/batch_hijack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


from modules import processing, scripts, shared
from internal_controlnet.external_code import ControlNetUnit
from scripts import controlnet, external_code, batch_hijack


Expand Down Expand Up @@ -73,15 +74,15 @@ def test_get_cn_batches__empty(self):
self.assertEqual(is_batch, False)

def test_get_cn_batches__1_simple(self):
self.p.script_args.append(external_code.ControlNetUnit(image=get_dummy_image()))
self.p.script_args.append(ControlNetUnit(image=get_dummy_image()))
self.assert_get_cn_batches_works([
[self.p.script_args[0].image],
])

def test_get_cn_batches__2_simples(self):
self.p.script_args.extend([
external_code.ControlNetUnit(image=get_dummy_image(0)),
external_code.ControlNetUnit(image=get_dummy_image(1)),
ControlNetUnit(image=get_dummy_image(0)),
ControlNetUnit(image=get_dummy_image(1)),
])
self.assert_get_cn_batches_works([
[get_dummy_image(0)],
Expand Down Expand Up @@ -135,7 +136,7 @@ def test_get_cn_batches__2_batches(self):

def test_get_cn_batches__2_mixed(self):
self.p.script_args.extend([
external_code.ControlNetUnit(image=get_dummy_image(0)),
ControlNetUnit(image=get_dummy_image(0)),
controlnet.UiControlNetUnit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
Expand All @@ -157,7 +158,7 @@ def test_get_cn_batches__2_mixed(self):

def test_get_cn_batches__3_mixed(self):
self.p.script_args.extend([
external_code.ControlNetUnit(image=get_dummy_image(0)),
ControlNetUnit(image=get_dummy_image(0)),
controlnet.UiControlNetUnit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
Expand Down Expand Up @@ -242,8 +243,8 @@ def test_process_images_no_units_forwards(self):

def test_process_images__only_simple_units__forwards(self):
self.p.script_args = [
external_code.ControlNetUnit(image=get_dummy_image()),
external_code.ControlNetUnit(image=get_dummy_image()),
ControlNetUnit(image=get_dummy_image()),
ControlNetUnit(image=get_dummy_image()),
]
self.assert_process_images_hijack_called(batch_count=0)

Expand Down
7 changes: 4 additions & 3 deletions tests/cn_script/cn_script_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scripts import external_code
from scripts.enums import ResizeMode
from scripts.controlnet import prepare_mask, Script, set_numpy_seed
from internal_controlnet.external_code import ControlNetUnit
from modules import processing


Expand Down Expand Up @@ -127,7 +128,7 @@ def test_choose_input_image(self):
with self.assertRaises(ValueError):
Script.choose_input_image(
p=processing.StableDiffusionProcessing(),
unit=external_code.ControlNetUnit(),
unit=ControlNetUnit(),
idx=0,
)

Expand All @@ -137,7 +138,7 @@ def test_choose_input_image(self):
init_images=[TestScript.sample_np_image],
resize_mode=ResizeMode.OUTER_FIT,
),
unit=external_code.ControlNetUnit(
unit=ControlNetUnit(
image=TestScript.sample_base64_image,
module="none",
resize_mode=ResizeMode.INNER_FIT,
Expand All @@ -152,7 +153,7 @@ def test_choose_input_image(self):
init_images=[TestScript.sample_np_image],
resize_mode=ResizeMode.OUTER_FIT,
),
unit=external_code.ControlNetUnit(
unit=ControlNetUnit(
module="none",
resize_mode=ResizeMode.INNER_FIT,
),
Expand Down
Loading

0 comments on commit abafad2

Please sign in to comment.