Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop external_code. prefix #2850

Merged
merged 2 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Register all preprocessors.
import scripts.preprocessor as preprocessor_init # noqa
from annotator.util import HWC3
from internal_controlnet.external_code import ControlNetUnit
from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.controlnet_lllite import clear_all_lllite
Expand Down Expand Up @@ -228,7 +229,7 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor:

def get_control(
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
idx: int,
control_model_type: ControlModelType,
preprocessor: Preprocessor,
Expand Down Expand Up @@ -338,7 +339,7 @@ def __init__(self) -> None:
self.latest_network = None
self.input_image = None
self.latest_model_hash = ""
self.enabled_units: List[external_code.ControlNetUnit] = []
self.enabled_units: List[ControlNetUnit] = []
self.detected_map = []
self.post_processors = []
self.noise_modifier = None
Expand All @@ -356,7 +357,7 @@ def show(self, is_img2img):

@staticmethod
def get_default_ui_unit(is_ui=True):
cls = UiControlNetUnit if is_ui else external_code.ControlNetUnit
cls = UiControlNetUnit if is_ui else ControlNetUnit
return cls(
enabled=False,
module="none",
Expand Down Expand Up @@ -527,7 +528,7 @@ def get_element(obj, strict=False):
return attribute_value if attribute_value is not None else default

@staticmethod
def parse_remote_call(p, unit: external_code.ControlNetUnit, idx):
def parse_remote_call(p, unit: ControlNetUnit, idx):
selector = Script.get_remote_call

unit.enabled = selector(p, "control_net_enabled", unit.enabled, idx, strict=True)
Expand Down Expand Up @@ -688,7 +689,7 @@ def get_enabled_units(p):
@staticmethod
def choose_input_image(
p: processing.StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
idx: int
) -> Tuple[np.ndarray, ResizeMode]:
""" Choose input image from following sources with descending priority:
Expand All @@ -701,7 +702,7 @@ def choose_input_image(
- The input image in ndarray form.
- The resize mode.
"""
def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
def parse_unit_image(unit: ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
unit_has_multiple_images = (
isinstance(unit.image, list) and
len(unit.image) > 0 and
Expand Down Expand Up @@ -810,7 +811,7 @@ def decode_image(img) -> np.ndarray:
@staticmethod
def try_crop_image_with_a1111_mask(
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
unit: ControlNetUnit,
input_image: np.ndarray,
resize_mode: ResizeMode,
) -> np.ndarray:
Expand Down Expand Up @@ -863,7 +864,7 @@ def try_crop_image_with_a1111_mask(
return input_image

@staticmethod
def check_sd_version_compatible(unit: external_code.ControlNetUnit) -> None:
def check_sd_version_compatible(unit: ControlNetUnit) -> None:
"""
Checks whether the given ControlNet unit has model compatible with the currently
active sd model. An exception is thrown if ControlNet unit is detected to be
Expand Down
9 changes: 5 additions & 4 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
external_code,
)
from annotator.util import HWC3
from internal_controlnet.external_code import ControlNetUnit
from scripts.logging import logger
from scripts.controlnet_ui.openpose_editor import OpenposeEditor
from scripts.controlnet_ui.preset import ControlNetPresetUI
Expand Down Expand Up @@ -127,7 +128,7 @@ def set_component(self, component: gr.components.Component):
)


class UiControlNetUnit(external_code.ControlNetUnit):
class UiControlNetUnit(ControlNetUnit):
"""The data class that stores all states of a ControlNetUnit."""

def __init__(
Expand Down Expand Up @@ -167,7 +168,7 @@ def __init__(
self.output_dir = output_dir
self.loopback = loopback

def unfold_merged(self) -> List[external_code.ControlNetUnit]:
def unfold_merged(self) -> List[ControlNetUnit]:
"""Unfolds a merged unit to multiple units. Keeps the unit merged for
preprocessors that can accept multiple input images.
"""
Expand Down Expand Up @@ -220,7 +221,7 @@ class ControlNetUiGroup(object):
def __init__(
self,
is_img2img: bool,
default_unit: external_code.ControlNetUnit,
default_unit: ControlNetUnit,
photopea: Optional[Photopea],
):
# Whether callbacks have been registered.
Expand Down Expand Up @@ -1260,7 +1261,7 @@ def register_core_callbacks(self):
self.type_filter,
*[
getattr(self, key)
for key in vars(external_code.ControlNetUnit()).keys()
for key in vars(ControlNetUnit()).keys()
],
)
self.advanced_weight_control.register_callbacks(
Expand Down
16 changes: 8 additions & 8 deletions scripts/controlnet_ui/preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from modules import scripts
from modules.ui_components import ToolButton
from internal_controlnet.external_code import ControlNetUnit
from scripts.infotext import parse_unit, serialize_unit
from scripts.logging import logger
from scripts import external_code
from scripts.supported_preprocessor import Preprocessor

save_symbol = "\U0001f4be" # 💾
Expand Down Expand Up @@ -113,15 +113,15 @@ def apply_preset(name: str, control_type: str, *ui_states):
gr.update(visible=False),
*(
(gr.skip(),)
* (len(vars(external_code.ControlNetUnit()).keys()) + 1)
* (len(vars(ControlNetUnit()).keys()) + 1)
),
)

assert name in ControlNetPresetUI.presets

infotext = ControlNetPresetUI.presets[name]
preset_unit = parse_unit(infotext)
current_unit = external_code.ControlNetUnit(*ui_states)
current_unit = ControlNetUnit(*ui_states)
preset_unit.image = None
current_unit.image = None

Expand All @@ -136,7 +136,7 @@ def apply_preset(name: str, control_type: str, *ui_states):
gr.update(visible=False),
*(
(gr.skip(),)
* (len(vars(external_code.ControlNetUnit()).keys()) + 1)
* (len(vars(ControlNetUnit()).keys()) + 1)
),
)

Expand Down Expand Up @@ -177,7 +177,7 @@ def save_preset(name: str, *ui_states):
return gr.update(visible=True), gr.update(), gr.update()

ControlNetPresetUI.save_preset(
name, external_code.ControlNetUnit(*ui_states)
name, ControlNetUnit(*ui_states)
)
return (
gr.update(), # name dialog
Expand Down Expand Up @@ -222,7 +222,7 @@ def save_new_preset(new_name: str, *ui_states):
return gr.update(visible=False), gr.update()

ControlNetPresetUI.save_preset(
new_name, external_code.ControlNetUnit(*ui_states)
new_name, ControlNetUnit(*ui_states)
)
return gr.update(visible=False), gr.update(
choices=ControlNetPresetUI.dropdown_choices(), value=new_name
Expand All @@ -248,7 +248,7 @@ def update_reset_button(preset_name: str, *ui_states):

infotext = ControlNetPresetUI.presets[preset_name]
preset_unit = parse_unit(infotext)
current_unit = external_code.ControlNetUnit(*ui_states)
current_unit = ControlNetUnit(*ui_states)
preset_unit.image = None
current_unit.image = None

Expand Down Expand Up @@ -279,7 +279,7 @@ def dropdown_choices() -> List[str]:
return list(ControlNetPresetUI.presets.keys()) + [NEW_PRESET]

@staticmethod
def save_preset(name: str, unit: external_code.ControlNetUnit):
def save_preset(name: str, unit: ControlNetUnit):
infotext = serialize_unit(unit)
with open(
os.path.join(ControlNetPresetUI.preset_directory, f"{name}.txt"), "w"
Expand Down
16 changes: 8 additions & 8 deletions scripts/infotext.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from modules.processing import StableDiffusionProcessing

from scripts import external_code
from internal_controlnet.external_code import ControlNetUnit
from scripts.logging import logger


Expand All @@ -28,12 +28,12 @@ def parse_value(value: str) -> Union[str, float, int, bool]:
return value # Plain string.


def serialize_unit(unit: external_code.ControlNetUnit) -> str:
excluded_fields = external_code.ControlNetUnit.infotext_excluded_fields()
def serialize_unit(unit: ControlNetUnit) -> str:
excluded_fields = ControlNetUnit.infotext_excluded_fields()

log_value = {
field_to_displaytext(field): getattr(unit, field)
for field in vars(external_code.ControlNetUnit()).keys()
for field in vars(ControlNetUnit()).keys()
if field not in excluded_fields and getattr(unit, field) != -1
# Note: exclude hidden slider values.
}
Expand All @@ -44,8 +44,8 @@ def serialize_unit(unit: external_code.ControlNetUnit) -> str:
return ", ".join(f"{field}: {value}" for field, value in log_value.items())


def parse_unit(text: str) -> external_code.ControlNetUnit:
return external_code.ControlNetUnit(
def parse_unit(text: str) -> ControlNetUnit:
return ControlNetUnit(
enabled=True,
**{
displaytext_to_field(key): parse_value(value)
Expand Down Expand Up @@ -74,7 +74,7 @@ def register_unit(self, unit_index: int, uigroup) -> None:
iocomponents.
"""
unit_prefix = Infotext.unit_prefix(unit_index)
for field in vars(external_code.ControlNetUnit()).keys():
for field in vars(ControlNetUnit()).keys():
# Exclude image for infotext.
if field == "image":
continue
Expand All @@ -88,7 +88,7 @@ def register_unit(self, unit_index: int, uigroup) -> None:

@staticmethod
def write_infotext(
units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing
units: List[ControlNetUnit], p: StableDiffusionProcessing
):
"""Write infotext to `p`."""
p.extra_generation_params.update(
Expand Down
15 changes: 8 additions & 7 deletions tests/cn_script/batch_hijack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


from modules import processing, scripts, shared
from internal_controlnet.external_code import ControlNetUnit
from scripts import controlnet, external_code, batch_hijack


Expand Down Expand Up @@ -73,15 +74,15 @@ def test_get_cn_batches__empty(self):
self.assertEqual(is_batch, False)

def test_get_cn_batches__1_simple(self):
self.p.script_args.append(external_code.ControlNetUnit(image=get_dummy_image()))
self.p.script_args.append(ControlNetUnit(image=get_dummy_image()))
self.assert_get_cn_batches_works([
[self.p.script_args[0].image],
])

def test_get_cn_batches__2_simples(self):
self.p.script_args.extend([
external_code.ControlNetUnit(image=get_dummy_image(0)),
external_code.ControlNetUnit(image=get_dummy_image(1)),
ControlNetUnit(image=get_dummy_image(0)),
ControlNetUnit(image=get_dummy_image(1)),
])
self.assert_get_cn_batches_works([
[get_dummy_image(0)],
Expand Down Expand Up @@ -135,7 +136,7 @@ def test_get_cn_batches__2_batches(self):

def test_get_cn_batches__2_mixed(self):
self.p.script_args.extend([
external_code.ControlNetUnit(image=get_dummy_image(0)),
ControlNetUnit(image=get_dummy_image(0)),
controlnet.UiControlNetUnit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
Expand All @@ -157,7 +158,7 @@ def test_get_cn_batches__2_mixed(self):

def test_get_cn_batches__3_mixed(self):
self.p.script_args.extend([
external_code.ControlNetUnit(image=get_dummy_image(0)),
ControlNetUnit(image=get_dummy_image(0)),
controlnet.UiControlNetUnit(
input_mode=batch_hijack.InputMode.BATCH,
batch_images=[
Expand Down Expand Up @@ -242,8 +243,8 @@ def test_process_images_no_units_forwards(self):

def test_process_images__only_simple_units__forwards(self):
self.p.script_args = [
external_code.ControlNetUnit(image=get_dummy_image()),
external_code.ControlNetUnit(image=get_dummy_image()),
ControlNetUnit(image=get_dummy_image()),
ControlNetUnit(image=get_dummy_image()),
]
self.assert_process_images_hijack_called(batch_count=0)

Expand Down
7 changes: 4 additions & 3 deletions tests/cn_script/cn_script_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scripts import external_code
from scripts.enums import ResizeMode
from scripts.controlnet import prepare_mask, Script, set_numpy_seed
from internal_controlnet.external_code import ControlNetUnit
from modules import processing


Expand Down Expand Up @@ -127,7 +128,7 @@ def test_choose_input_image(self):
with self.assertRaises(ValueError):
Script.choose_input_image(
p=processing.StableDiffusionProcessing(),
unit=external_code.ControlNetUnit(),
unit=ControlNetUnit(),
idx=0,
)

Expand All @@ -137,7 +138,7 @@ def test_choose_input_image(self):
init_images=[TestScript.sample_np_image],
resize_mode=ResizeMode.OUTER_FIT,
),
unit=external_code.ControlNetUnit(
unit=ControlNetUnit(
image=TestScript.sample_base64_image,
module="none",
resize_mode=ResizeMode.INNER_FIT,
Expand All @@ -152,7 +153,7 @@ def test_choose_input_image(self):
init_images=[TestScript.sample_np_image],
resize_mode=ResizeMode.OUTER_FIT,
),
unit=external_code.ControlNetUnit(
unit=ControlNetUnit(
module="none",
resize_mode=ResizeMode.INNER_FIT,
),
Expand Down
Loading
Loading