Skip to content

Commit

Permalink
Fix Lora from text input #540
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Mar 29, 2024
1 parent 87815e3 commit a8a4ec3
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
1 change: 1 addition & 0 deletions ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class ClientModels:

def __init__(self) -> None:
self.node_inputs = {}
self.resources = {}

def resource(
self, kind: ResourceKind, identifier: ControlMode | UpscalerName | str, version: SDVersion
Expand Down
5 changes: 3 additions & 2 deletions ai_diffusion/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Tuple, List

from .api import LoraInput
from .util import client_logger as log


Expand All @@ -20,7 +21,7 @@ def merge_prompt(prompt: str, style_prompt: str):


def extract_loras(prompt: str, client_loras: list[str]):
loras = []
loras: list[LoraInput] = []
for match in _pattern_lora.findall(prompt):
lora_name = ""

Expand All @@ -42,7 +43,7 @@ def extract_loras(prompt: str, client_loras: list[str]):
log.warning(error)
raise Exception(error)

loras.append(dict(name=lora_name, strength=lora_strength))
loras.append(LoraInput(lora_name, lora_strength))
return _pattern_lora.sub("", prompt), loras


Expand Down
29 changes: 27 additions & 2 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from typing import Any

from ai_diffusion import workflow
from ai_diffusion.api import WorkflowKind, WorkflowInput, ControlInput, InpaintMode, FillMode
from ai_diffusion.api import TextInput
from ai_diffusion.api import LoraInput, WorkflowKind, WorkflowInput, ControlInput
from ai_diffusion.api import InpaintMode, FillMode, TextInput
from ai_diffusion.client import ClientModels, CheckpointInfo
from ai_diffusion.comfy_client import ComfyClient
from ai_diffusion.cloud_client import CloudClient
from ai_diffusion.resources import ControlMode
Expand Down Expand Up @@ -146,6 +147,30 @@ def test_inpaint_params():
assert e.use_condition_mask == False


def test_prepare_lora():
models = ClientModels()
models.checkpoints = {"CP": CheckpointInfo("CP", SDVersion.sd15)}
models.loras = ["PINK_UNICORNS", "MOTHER_OF_PEARL"]
style = Style(Path("default.json"))
style.sd_checkpoint = "CP"
style.loras.append(dict(name="MOTHER_OF_PEARL", strength=0.33))
job = workflow.prepare(
WorkflowKind.generate,
canvas=Extent(512, 512),
text=TextInput("test <lora:PINK_UNICORNS:0.77>"),
style=style,
seed=29,
models=models,
perf=default_perf,
)
assert job.text and job.text.positive == "test"
assert (
job.models
and LoraInput("PINK_UNICORNS", 0.77) in job.models.loras
and LoraInput("MOTHER_OF_PEARL", 0.33) in job.models.loras
)


@pytest.mark.parametrize("extent", [Extent(256, 256), Extent(800, 800), Extent(512, 1024)])
def test_generate(qtapp, client, extent: Extent):
prompt = TextInput("ship")
Expand Down

0 comments on commit a8a4ec3

Please sign in to comment.