Skip to content

Commit

Permalink
Allow backends to report fixable errors
Browse files Browse the repository at this point in the history
  • Loading branch information
carson-katri committed May 21, 2023
1 parent e8113d0 commit c8ef88a
Show file tree
Hide file tree
Showing 14 changed files with 320 additions and 88 deletions.
51 changes: 30 additions & 21 deletions api/backend/backend.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
try:
import bpy
from typing import Callable, List, Tuple
from ..models.generation_arguments import GenerationArguments
from ..models.generation_result import GenerationResult
from ..models.task import Task
from ..models.model import Model
from ..models.prompt import Prompt
from ..models.seamless_axes import SeamlessAxes
from ..models.step_preview_mode import StepPreviewMode

StepCallback = Callable[[GenerationResult], None]
Callback = Callable[[List[GenerationResult] | Exception], None]
Expand All @@ -20,11 +17,7 @@ class Backend(bpy.types.PropertyGroup):
def list_models(self) -> List[Model]
def generate(
self,
task: Task,
model: Model,
prompt: Prompt,
size: Tuple[int, int] | None,
seamless_axes: SeamlessAxes,
arguments: GenerationArguments,
step_callback: StepCallback,
callback: Callback
Expand Down Expand Up @@ -91,22 +84,38 @@ def draw_extra(self, layout, context):

def generate(
self,
task: Task,
model: Model,
prompt: Prompt,
size: Tuple[int, int] | None,
seed: int,
steps: int,
guidance_scale: float,
scheduler: str,
seamless_axes: SeamlessAxes,
step_preview_mode: StepPreviewMode,
iterations: int,

arguments: GenerationArguments,
step_callback: StepCallback,
callback: Callback
):
"""A request to generate an image."""
...

def validate(
self,
arguments: GenerationArguments
):
"""Validates the given arguments in the UI without generating.
This validation should occur as quickly as possible.
To report problems with the inputs, raise a `ValueError`.
Use the `FixItError` to provide a solution to the problem as well.
```python
if arguments.steps % 2 == 0:
throw FixItError(
"The number of steps is even",
solution=FixItError.UpdateGenerationArgumentsSolution(
title="Add 1 more step",
arguments=dataclasses.replace(
arguments,
steps=arguments.steps + 1
)
)
)
```
"""
...
except:
pass
3 changes: 2 additions & 1 deletion api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from .prompt import *
from .seamless_axes import *
from .step_preview_mode import *
from .task import *
from .task import *
from .fix_it_error import *
41 changes: 41 additions & 0 deletions api/models/fix_it_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from typing import Callable, Any
from .generation_arguments import GenerationArguments
from dataclasses import dataclass

class FixItError(Exception):
"""An exception with a solution.
Call the `draw` method to render the UI elements responsible for resolving this error.
"""
def __init__(self, message, solution: 'Solution'):
super().__init__(message)

self._solution = solution

def _draw(self, dream_prompt, context, layout):
self._solution._draw(dream_prompt, context, layout)

@dataclass
class Solution:
def _draw(self, dream_prompt, context, layout):
...

@dataclass
class ChangeProperty(Solution):
"""Prompts the user to change the given `property` of the `GenerationArguments`."""
property: str

def _draw(self, dream_prompt, context, layout):
layout.prop(dream_prompt, self.property)

@dataclass
class RunOperator(Solution):
"""Runs the given operator"""
title: str
operator: str
modify_operator: Callable[[Any], None]

def _draw(self, dream_prompt, context, layout):
self.modify_operator(
layout.operator(self.operator, text=self.title)
)
104 changes: 104 additions & 0 deletions api/models/generation_arguments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from dataclasses import dataclass
from typing import Tuple, List
from ..models.task import Task
from ..models.model import Model
from ..models.prompt import Prompt
from ..models.seamless_axes import SeamlessAxes
from ..models.step_preview_mode import StepPreviewMode

@dataclass
class GenerationArguments:
task: Task
"""The type of generation to perform.
Use a match statement to perform different actions based on the selected task.
```python
match task:
case PromptToImage():
...
case ImageToImage(image=image, strength=strength, fit=fit):
...
case Inpaint(image=image, fit=fit, strength=strength, mask_source=mask_source, mask_prompt=mask_prompt, confidence=confidence):
...
case DepthToImage(depth=depth, image=image, strength=strength):
...
case Outpaint(image=image, origin=origin):
...
case _:
raise NotImplementedError()
```
"""

model: Model
"""The selected model.
This is one of the options provided by `Backend.list_models`.
"""

prompt: Prompt
"""The positive and (optionally) negative prompt.
If `prompt.negative` is `None`, then the 'Negative Prompt' panel was disabled by the user.
"""

size: Tuple[int, int] | None
"""The target size of the image, or `None` to use the native size of the model."""

seed: int
"""The random or user-provided seed to use."""

steps: int
"""The number of inference steps to perform."""

guidance_scale: float
"""The selected classifier-free guidance scale."""

scheduler: str
"""The selected scheduler.
This is one of the options provided by `Backend.list_schedulers`.
"""

seamless_axes: SeamlessAxes
"""Which axes to tile seamlessly."""

step_preview_mode: StepPreviewMode
"""The style of preview to display at each step."""

iterations: int
"""The number of images to generate.
The value sent to `callback` should contain the same number of `GenerationResult` instances in a list.
"""

@staticmethod
def _map_property_name(name: str) -> str | List[str] | None:
"""Converts a property name from `GenerationArguments` to the corresponding property of a `DreamPrompt`."""
match name:
case "model":
return "model"
case "prompt":
return ["prompt", "use_negative_prompt", "negative_prompt"]
case "prompt.positive":
return "prompt"
case "prompt.negative":
return ["use_negative_prompt", "negative_prompt"]
case "size":
return ["use_size", "width", "height"]
case "seed":
return "seed"
case "steps":
return "steps"
case "guidance_scale":
return "cfg_scale"
case "scheduler":
return "scheduler"
case "seamless_axes":
return "seamless_axes"
case "step_preview_mode":
return "step_preview_mode"
case "iterations":
return "iterations"
case _:
return None
33 changes: 26 additions & 7 deletions api/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,38 +31,57 @@ class Task:
...
```
"""
pass

@classmethod
def name(cls) -> str:
"unknown"
"""A human readable name for this task."""

@dataclass
class PromptToImage(Task):
pass
@classmethod
def name(cls):
return "prompt to image"

@dataclass
class ImageToImage(Task):
image: NDArray
strength: float
fit: bool

@classmethod
def name(cls):
return "image to image"

@dataclass
class Inpaint(Task):
class Inpaint(ImageToImage):
class MaskSource(IntEnum):
ALPHA = 0
PROMPT = 1

image: NDArray
strength: float
fit: bool
mask_source: MaskSource
mask_prompt: str
confidence: float

@classmethod
def name(cls):
return "inpainting"

@dataclass
class DepthToImage(Task):
depth: NDArray | None
image: NDArray | None
strength: float

@classmethod
def name(cls):
return "depth to image"

@dataclass
class Outpaint(Task):
image: NDArray
origin: Tuple[int, int]
origin: Tuple[int, int]

@classmethod
def name(cls):
return "outpainting"
Loading

0 comments on commit c8ef88a

Please sign in to comment.