diff --git a/api/backend/backend.py b/api/backend/backend.py index 3beb5285..ce327ba2 100644 --- a/api/backend/backend.py +++ b/api/backend/backend.py @@ -1,12 +1,9 @@ try: import bpy from typing import Callable, List, Tuple + from ..models.generation_arguments import GenerationArguments from ..models.generation_result import GenerationResult - from ..models.task import Task from ..models.model import Model - from ..models.prompt import Prompt - from ..models.seamless_axes import SeamlessAxes - from ..models.step_preview_mode import StepPreviewMode StepCallback = Callable[[GenerationResult], None] Callback = Callable[[List[GenerationResult] | Exception], None] @@ -20,11 +17,7 @@ class Backend(bpy.types.PropertyGroup): def list_models(self) -> List[Model] def generate( self, - task: Task, - model: Model, - prompt: Prompt, - size: Tuple[int, int] | None, - seamless_axes: SeamlessAxes, + arguments: GenerationArguments, step_callback: StepCallback, callback: Callback @@ -91,22 +84,38 @@ def draw_extra(self, layout, context): def generate( self, - task: Task, - model: Model, - prompt: Prompt, - size: Tuple[int, int] | None, - seed: int, - steps: int, - guidance_scale: float, - scheduler: str, - seamless_axes: SeamlessAxes, - step_preview_mode: StepPreviewMode, - iterations: int, - + arguments: GenerationArguments, step_callback: StepCallback, callback: Callback ): """A request to generate an image.""" ... + + def validate( + self, + arguments: GenerationArguments + ): + """Validates the given arguments in the UI without generating. + + This validation should occur as quickly as possible. + + To report problems with the inputs, raise a `ValueError`. + Use the `FixItError` to provide a solution to the problem as well. 
+ + ```python + if arguments.steps % 2 == 0: + raise FixItError( + "The number of steps is even", + solution=FixItError.UpdateGenerationArgumentsSolution( + title="Add 1 more step", + arguments=dataclasses.replace( + arguments, + steps=arguments.steps + 1 + ) + ) + ) + ``` + """ + ... except: pass \ No newline at end of file diff --git a/api/models/__init__.py b/api/models/__init__.py index 5041e364..8219c44b 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -3,4 +3,5 @@ from .prompt import * from .seamless_axes import * from .step_preview_mode import * -from .task import * \ No newline at end of file +from .task import * +from .fix_it_error import * \ No newline at end of file diff --git a/api/models/fix_it_error.py b/api/models/fix_it_error.py new file mode 100644 index 00000000..3292c3f4 --- /dev/null +++ b/api/models/fix_it_error.py @@ -0,0 +1,41 @@ +from typing import Callable, Any +from .generation_arguments import GenerationArguments +from dataclasses import dataclass + +class FixItError(Exception): + """An exception with a solution. + + Call the `draw` method to render the UI elements responsible for resolving this error. + """ + def __init__(self, message, solution: 'Solution'): + super().__init__(message) + + self._solution = solution + + def _draw(self, dream_prompt, context, layout): + self._solution._draw(dream_prompt, context, layout) + + @dataclass + class Solution: + def _draw(self, dream_prompt, context, layout): + ...
+ + @dataclass + class ChangeProperty(Solution): + """Prompts the user to change the given `property` of the `GenerationArguments`.""" + property: str + + def _draw(self, dream_prompt, context, layout): + layout.prop(dream_prompt, self.property) + + @dataclass + class RunOperator(Solution): + """Runs the given operator""" + title: str + operator: str + modify_operator: Callable[[Any], None] + + def _draw(self, dream_prompt, context, layout): + self.modify_operator( + layout.operator(self.operator, text=self.title) + ) \ No newline at end of file diff --git a/api/models/generation_arguments.py b/api/models/generation_arguments.py new file mode 100644 index 00000000..05406e4b --- /dev/null +++ b/api/models/generation_arguments.py @@ -0,0 +1,104 @@ +from dataclasses import dataclass +from typing import Tuple, List +from ..models.task import Task +from ..models.model import Model +from ..models.prompt import Prompt +from ..models.seamless_axes import SeamlessAxes +from ..models.step_preview_mode import StepPreviewMode + +@dataclass +class GenerationArguments: + task: Task + """The type of generation to perform. + + Use a match statement to perform different actions based on the selected task. + + ```python + match task: + case PromptToImage(): + ... + case ImageToImage(image=image, strength=strength, fit=fit): + ... + case Inpaint(image=image, fit=fit, strength=strength, mask_source=mask_source, mask_prompt=mask_prompt, confidence=confidence): + ... + case DepthToImage(depth=depth, image=image, strength=strength): + ... + case Outpaint(image=image, origin=origin): + ... + case _: + raise NotImplementedError() + ``` + """ + + model: Model + """The selected model. + + This is one of the options provided by `Backend.list_models`. + """ + + prompt: Prompt + """The positive and (optionally) negative prompt. + + If `prompt.negative` is `None`, then the 'Negative Prompt' panel was disabled by the user. 
+ """ + + size: Tuple[int, int] | None + """The target size of the image, or `None` to use the native size of the model.""" + + seed: int + """The random or user-provided seed to use.""" + + steps: int + """The number of inference steps to perform.""" + + guidance_scale: float + """The selected classifier-free guidance scale.""" + + scheduler: str + """The selected scheduler. + + This is one of the options provided by `Backend.list_schedulers`. + """ + + seamless_axes: SeamlessAxes + """Which axes to tile seamlessly.""" + + step_preview_mode: StepPreviewMode + """The style of preview to display at each step.""" + + iterations: int + """The number of images to generate. + + The value sent to `callback` should contain the same number of `GenerationResult` instances in a list. + """ + + @staticmethod + def _map_property_name(name: str) -> str | List[str] | None: + """Converts a property name from `GenerationArguments` to the corresponding property of a `DreamPrompt`.""" + match name: + case "model": + return "model" + case "prompt": + return ["prompt", "use_negative_prompt", "negative_prompt"] + case "prompt.positive": + return "prompt" + case "prompt.negative": + return ["use_negative_prompt", "negative_prompt"] + case "size": + return ["use_size", "width", "height"] + case "seed": + return "seed" + case "steps": + return "steps" + case "guidance_scale": + return "cfg_scale" + case "scheduler": + return "scheduler" + case "seamless_axes": + return "seamless_axes" + case "step_preview_mode": + return "step_preview_mode" + case "iterations": + return "iterations" + case _: + return None \ No newline at end of file diff --git a/api/models/task.py b/api/models/task.py index 1ff11bcc..106fa3b3 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -31,11 +31,17 @@ class Task: ... 
``` """ - pass + + @classmethod + def name(cls) -> str: + """A human readable name for this task.""" + return "unknown" @dataclass class PromptToImage(Task): - pass + @classmethod + def name(cls): + return "prompt to image" @dataclass class ImageToImage(Task): image: NDArray strength: float fit: bool + @classmethod + def name(cls): + return "image to image" + @dataclass -class Inpaint(Task): +class Inpaint(ImageToImage): class MaskSource(IntEnum): ALPHA = 0 PROMPT = 1 - image: NDArray - strength: float - fit: bool mask_source: MaskSource mask_prompt: str confidence: float + @classmethod + def name(cls): + return "inpainting" + @dataclass class DepthToImage(Task): depth: NDArray | None image: NDArray | None strength: float + @classmethod + def name(cls): + return "depth to image" + @dataclass class Outpaint(Task): image: NDArray - origin: Tuple[int, int] \ No newline at end of file + origin: Tuple[int, int] + + @classmethod + def name(cls): + return "outpainting" \ No newline at end of file diff --git a/diffusers_backend.py b/diffusers_backend.py index 1ab23cb2..368c9560 100644 --- a/diffusers_backend.py +++ b/diffusers_backend.py @@ -1,16 +1,19 @@ +import bpy from bpy.props import FloatProperty, IntProperty, EnumProperty, BoolProperty -from typing import List, Tuple +from typing import List from .api import Backend, StepCallback, Callback -from .api.models import Model, Task, Prompt, SeamlessAxes, GenerationResult, StepPreviewMode +from .api.models import Model, GenerationArguments, GenerationResult from .api.models.task import PromptToImage, ImageToImage, Inpaint, DepthToImage, Outpaint +from .api.models.fix_it_error import FixItError from .generator_process import Generator from .generator_process.actions.prompt_to_image import ImageGenerationResult from .generator_process.future import Future from .generator_process.models import Optimizations, Scheduler from .generator_process.actions.huggingface_hub import ModelType -from .preferences
import StableDiffusionPreferences + +from .preferences import StableDiffusionPreferences, _template_model_download_progress, InstallModel from functools import reduce @@ -105,26 +108,26 @@ def optimizations(self) -> Optimizations: optimizations.attention_slice_size = 'auto' return optimizations - def generate(self, task: Task, model: Model, prompt: Prompt, size: Tuple[int, int] | None, seed: int, steps: int, guidance_scale: float, scheduler: str, seamless_axes: SeamlessAxes, step_preview_mode: StepPreviewMode, iterations: int, step_callback: StepCallback, callback: Callback): + def generate(self, arguments: GenerationArguments, step_callback: StepCallback, callback: Callback): gen = Generator.shared() common_kwargs = { - 'model': model.id, - 'scheduler': Scheduler(scheduler), + 'model': arguments.model.id, + 'scheduler': Scheduler(arguments.scheduler), 'optimizations': self.optimizations(), - 'prompt': prompt.positive, - 'steps': steps, - 'width': size[0] if size is not None else None, - 'height': size[1] if size is not None else None, - 'seed': seed, - 'cfg_scale': guidance_scale, - 'use_negative_prompt': prompt.negative is not None, - 'negative_prompt': prompt.negative or "", - 'seamless_axes': seamless_axes, - 'iterations': iterations, - 'step_preview_mode': step_preview_mode, + 'prompt': arguments.prompt.positive, + 'steps': arguments.steps, + 'width': arguments.size[0] if arguments.size is not None else None, + 'height': arguments.size[1] if arguments.size is not None else None, + 'seed': arguments.seed, + 'cfg_scale': arguments.guidance_scale, + 'use_negative_prompt': arguments.prompt.negative is not None, + 'negative_prompt': arguments.prompt.negative or "", + 'seamless_axes': arguments.seamless_axes, + 'iterations': arguments.iterations, + 'step_preview_mode': arguments.step_preview_mode, } future: Future - match task: + match arguments.task: case PromptToImage(): future = gen.prompt_to_image(**common_kwargs) case ImageToImage(image=image, 
strength=strength, fit=fit): @@ -157,11 +160,11 @@ def generate(self, task: Task, model: Model, prompt: Prompt, size: Tuple[int, in case _: raise NotImplementedError() def on_step(_, step_image: ImageGenerationResult): - step_callback(GenerationResult(progress=step_image.step, total=steps, image=step_image.images[-1], seed=step_image.seeds[-1])) + step_callback(GenerationResult(progress=step_image.step, total=arguments.steps, image=step_image.images[-1], seed=step_image.seeds[-1])) def on_done(future: Future): result: ImageGenerationResult = future.result(last_only=True) callback([ - GenerationResult(progress=result.step, total=steps, image=result.images[i], seed=result.seeds[i]) + GenerationResult(progress=result.step, total=arguments.steps, image=result.images[i], seed=result.seeds[i]) for i in range(len(result.images)) ]) def on_exception(_, exception): @@ -170,6 +173,34 @@ def on_exception(_, exception): future.add_exception_callback(on_exception) future.add_done_callback(on_done) + def validate(self, arguments: GenerationArguments): + installed_models = bpy.context.preferences.addons[StableDiffusionPreferences.bl_idname].preferences.installed_models + model = next((m for m in installed_models if m.model_base == arguments.model.id), None) + if model is None: + raise FixItError("No model selected.", FixItError.ChangeProperty("model")) + else: + if not ModelType[model.model_type].matches_task(arguments.task): + class DownloadModel(FixItError.Solution): + def _draw(self, dream_prompt, context, layout): + if not _template_model_download_progress(context, layout): + target_model_type = ModelType.from_task(arguments.task) + if target_model_type is not None: + install_model = layout.operator(InstallModel.bl_idname, text=f"Download {target_model_type.recommended_model()} (Recommended)", icon="IMPORT") + install_model.model = target_model_type.recommended_model() + install_model.prefer_fp16_revision = 
context.preferences.addons[StableDiffusionPreferences.bl_idname].preferences.prefer_fp16_revision + model_task_description = f"""Incorrect model type selected for {type(arguments.task).name().replace('_', ' ').lower()} tasks. +The selected model is for {model.model_type.replace('_', ' ').lower()}.""" + if not any(ModelType[m.model_type].matches_task(arguments.task) for m in installed_models): + raise FixItError( + message=model_task_description + "\nYou do not have any compatible models downloaded:", + solution=DownloadModel() + ) + else: + raise FixItError( + message=model_task_description + "\nSelect a different model below.", + solution=FixItError.ChangeProperty("model") + ) + def draw_speed_optimizations(self, layout, context): inferred_device = Optimizations.infer_device() if self.cpu_only: diff --git a/generator_process/actions/huggingface_hub.py b/generator_process/actions/huggingface_hub.py index ebaba892..a3231c8c 100644 --- a/generator_process/actions/huggingface_hub.py +++ b/generator_process/actions/huggingface_hub.py @@ -15,6 +15,7 @@ import json import enum from ..future import Future +from ...api.models.task import * class ModelType(enum.IntEnum): """ @@ -48,6 +49,41 @@ def recommended_model(self) -> str: return "stabilityai/stable-diffusion-2-inpainting" case _: return "stabilityai/stable-diffusion-2-1" + + def matches_task(self, task: Task) -> bool: + """Indicates if the model type is correct for a given `Task`. + + If not an error should be shown to the user to select a different model. 
+ """ + match task: + case PromptToImage(): + return self == ModelType.PROMPT_TO_IMAGE + case Inpaint(): + return self == ModelType.INPAINTING + case DepthToImage(): + return self == ModelType.DEPTH + case Outpaint(): + return self == ModelType.INPAINTING + case ImageToImage(): + return self == ModelType.PROMPT_TO_IMAGE + case _: + return False + + @staticmethod + def from_task(task: Task) -> 'ModelType | None': + match task: + case PromptToImage(): + return ModelType.PROMPT_TO_IMAGE + case Inpaint(): + return ModelType.INPAINTING + case DepthToImage(): + return ModelType.DEPTH + case Outpaint(): + return ModelType.INPAINTING + case ImageToImage(): + return ModelType.PROMPT_TO_IMAGE + case _: + return None @dataclass class Model: diff --git a/generator_process/models/__init__.py b/generator_process/models/__init__.py index cf1a0b83..99bf2a16 100644 --- a/generator_process/models/__init__.py +++ b/generator_process/models/__init__.py @@ -1,4 +1,3 @@ -from .fix_it_error import * from .image_generation_result import * from .optimizations import * from .scheduler import * diff --git a/generator_process/models/fix_it_error.py b/generator_process/models/fix_it_error.py deleted file mode 100644 index 5a19ee1f..00000000 --- a/generator_process/models/fix_it_error.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Callable, Any - -class FixItError(Exception): - """An exception with a solution. - - Call the `draw` method to render the UI elements responsible for resolving this error. 
- """ - def __init__(self, message, fix_it: Callable[[Any, Any], None]): - super().__init__(message) - - self._fix_it = fix_it - - def draw(self, context, layout): - self._fix_it(context, layout) \ No newline at end of file diff --git a/operators/dream_texture.py b/operators/dream_texture.py index 16b558f7..7e52fad6 100644 --- a/operators/dream_texture.py +++ b/operators/dream_texture.py @@ -33,7 +33,9 @@ class DreamTexture(bpy.types.Operator): @classmethod def poll(cls, context): try: - context.scene.dream_textures_prompt.validate(context) + prompt = context.scene.dream_textures_prompt + backend: api.Backend = prompt.get_backend() + backend.validate(prompt.generate_args(context)) except: return False return Generator.shared().can_use() diff --git a/operators/project.py b/operators/project.py index b22476b4..fd0d4dd9 100644 --- a/operators/project.py +++ b/operators/project.py @@ -17,7 +17,7 @@ from ..preferences import StableDiffusionPreferences from ..generator_process import Generator -from ..generator_process.models import FixItError +from ..api.models import FixItError from ..generator_process.actions.huggingface_hub import ModelType import tempfile @@ -163,7 +163,7 @@ def draw(self, context): error_box.use_property_split = False for i, line in enumerate(e.args[0].split('\n')): error_box.label(text=line, icon="ERROR" if i == 0 else "NONE") - e.draw(context, error_box) + e._draw(context.scene.dream_textures_project_prompt, context, error_box) except Exception as e: print(e) return ActionsPanel diff --git a/property_groups/dream_prompt.py b/property_groups/dream_prompt.py index 70af9580..4a5a59ed 100644 --- a/property_groups/dream_prompt.py +++ b/property_groups/dream_prompt.py @@ -80,6 +80,11 @@ def model_options(self, context): for model in self.get_backend().list_models(context) ] +def _model_update(self, context): + options = [m for m in model_options(self, context) if m is not None] + if self.model == '' and len(options) > 0: + self.model = options[0] + def 
backend_options(self, context): return [ (backend._id(), backend.name if hasattr(backend, "name") else backend.__name__, backend.description if hasattr(backend, "description") else "") @@ -97,7 +102,7 @@ def seed_clamp(self, ctx): attributes = { "backend": EnumProperty(name="Backend", items=backend_options, default=1, description="Specify which generation backend to use"), - "model": EnumProperty(name="Model", items=model_options, description="Specify which model to use for inference"), + "model": EnumProperty(name="Model", items=model_options, description="Specify which model to use for inference", update=_model_update), "control_nets": CollectionProperty(type=ControlNet), "active_control_net": IntProperty(name="Active ControlNet"), @@ -211,7 +216,7 @@ def get_optimizations(self: DreamPrompt): optimizations.attention_slice_size = 'auto' return optimizations -def generate_args(self, context, iteration=0): +def generate_args(self, context, iteration=0) -> api.GenerationArguments: is_file_batch = self.prompt_structure == file_batch_structure.id file_batch_lines = [] file_batch_lines_negative = [] @@ -255,7 +260,7 @@ def generate_args(self, context, iteration=0): task = api.DepthToImage( depth=np.array(context.scene.init_depth.pixels) .astype(np.float32) - .reshape((scene.init_depth.size[1], scene.init_depth.size[0], scene.init_depth.channels)), + .reshape((context.scene.init_depth.size[1], context.scene.init_depth.size[0], context.scene.init_depth.channels)), image=init_image, strength=self.strength ) @@ -280,22 +285,22 @@ def generate_args(self, context, iteration=0): origin=(self.outpaint_origin[0], self.outpaint_origin[1]) ) - args = { - 'task': task, - 'model': next(model for model in self.get_backend().list_models(context) if model is not None and model.id == self.model), - 'prompt': api.Prompt( + return api.GenerationArguments( + task=task, + model=next(model for model in self.get_backend().list_models(context) if model is not None and model.id == self.model), 
+ prompt=api.Prompt( file_batch_lines if is_file_batch else self.generate_prompt(), file_batch_lines_negative if is_file_batch else (self.negative_prompt if self.use_negative_prompt else None) ), - 'size': (self.width, self.height) if self.use_size else None, - 'seed': self.get_seed(), - 'steps': self.steps, - 'guidance_scale': self.cfg_scale, - 'scheduler': self.scheduler, - 'seamless_axes': SeamlessAxes(self.seamless_axes), - 'step_preview_mode': StepPreviewMode(self.step_preview_mode), - 'iterations': self.iterations - } + size=(self.width, self.height) if self.use_size else None, + seed=self.get_seed(), + steps=self.steps, + guidance_scale=self.cfg_scale, + scheduler=self.scheduler, + seamless_axes=SeamlessAxes(self.seamless_axes), + step_preview_mode=StepPreviewMode(self.step_preview_mode), + iterations=self.iterations + ) # args['control'] = [ # np.flipud( # np.array(net.control_image.pixels) @@ -304,7 +309,6 @@ def generate_args(self, context, iteration=0): # for net in args['control_nets'] # if net.control_image is not None # ] - return args def get_backend(self) -> api.Backend: return getattr(self, api.Backend._lookup(self.backend)._attribute()) diff --git a/property_groups/dream_prompt_validation.py b/property_groups/dream_prompt_validation.py index 69da4857..71fa1a2b 100644 --- a/property_groups/dream_prompt_validation.py +++ b/property_groups/dream_prompt_validation.py @@ -1,5 +1,4 @@ from ..preferences import StableDiffusionPreferences, _template_model_download_progress, InstallModel -from ..generator_process.models import FixItError from ..generator_process.actions.huggingface_hub import ModelType from ..preferences import OpenURL diff --git a/ui/panels/dream_texture.py b/ui/panels/dream_texture.py index 5d4dc464..e93a2ded 100644 --- a/ui/panels/dream_texture.py +++ b/ui/panels/dream_texture.py @@ -16,7 +16,7 @@ from ...property_groups.dream_prompt import DreamPrompt, backend_options from ...generator_process.actions.prompt_to_image import 
Optimizations from ...generator_process.actions.detect_seamless import SeamlessAxes -from ...generator_process.models import FixItError +from ...api.models import FixItError from ... import api def dream_texture_panels(): @@ -357,13 +357,14 @@ def draw(self, context): # Validation try: - prompt.validate(context) + backend: api.Backend = prompt.get_backend() + backend.validate(prompt.generate_args(context)) except FixItError as e: error_box = layout.box() error_box.use_property_split = False for i, line in enumerate(e.args[0].split('\n')): error_box.label(text=line, icon="ERROR" if i == 0 else "NONE") - e.draw(context, error_box) + e._draw(prompt, context, error_box) except Exception as e: print(e) return ActionsPanel \ No newline at end of file