Skip to content

Commit

Permalink
fix(range): min range not work when max is unset
Browse files Browse the repository at this point in the history
  • Loading branch information
bluelovers committed Jul 19, 2024
1 parent 58815af commit 7e1e064
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 7 deletions.
16 changes: 11 additions & 5 deletions src/dynamicprompts/commands/variant_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dynamicprompts.commands.base import Command
from dynamicprompts.commands.literal_command import LiteralCommand
from dynamicprompts.enums import SamplingMethod
from dynamicprompts.utils import _fix_max_bound

logger = logging.getLogger(__name__)

Expand All @@ -21,15 +22,18 @@ class VariantOption:
class VariantCommand(Command):
variants: list[VariantOption]
min_bound: int = 1
max_bound: int = 1
max_bound: int = None
separator: str = ","
sampling_method: SamplingMethod | None = None

def __post_init__(self):
min_bound, max_bound = sorted((self.min_bound, self.max_bound))
min_bound = self.min_bound
if self.max_bound:
min_bound, max_bound = sorted((self.min_bound, self.max_bound))
object.__setattr__(self, "max_bound", max_bound)
min_bound = max(0, min_bound)
object.__setattr__(self, "min_bound", min_bound)
object.__setattr__(self, "max_bound", max_bound)


def __len__(self) -> int:
return len(self.variants)
Expand All @@ -49,8 +53,10 @@ def values(self) -> list[Command]:
return [p.value for p in self.variants]

def adjust_range(self) -> VariantCommand:
min_bound = min(self.min_bound, len(self.values))
max_bound = min(self.max_bound, len(self.values))
max_options = len(self.variants)
min_bound = min(self.min_bound, max_options)
# max_bound = min(self.max_bound, max_options)
max_bound = _fix_max_bound(self.max_bound, max_options)
return dataclasses.replace(self, min_bound=min_bound, max_bound=max_bound)

@classmethod
Expand Down
2 changes: 2 additions & 0 deletions src/dynamicprompts/parser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ def _parse_bound_expr(expr, max_options):
lbound = int(expr["range"]["lower"])
if "upper" in expr["range"]:
ubound = int(expr["range"]["upper"])
else:
ubound = None

if "separator" in expr:
separator = expr["separator"][0]
Expand Down
7 changes: 5 additions & 2 deletions src/dynamicprompts/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dynamicprompts.sampling_context import SamplingContext
from dynamicprompts.sampling_result import SamplingResult
from dynamicprompts.types import ResultGen
from dynamicprompts.utils import _fix_max_bound

logger = logging.getLogger(__name__)

Expand All @@ -26,8 +27,10 @@ def wildcard_to_variant(
) -> VariantCommand:
wildcard = next(iter(context.sample_prompts(command.wildcard, 1))).text
values = context.wildcard_manager.get_values(wildcard)
min_bound = min(min_bound, len(values))
max_bound = min(max_bound, len(values))
max_options = len(values)
min_bound = min(min_bound, max_options)
# max_bound = min(max_bound, len(values))
max_bound = _fix_max_bound(max_bound, max_options)

variant_options = [
VariantOption(parse(v, parser_config=context.parser_config))
Expand Down
5 changes: 5 additions & 0 deletions src/dynamicprompts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,8 @@ def choose_without_replacement(
weights.remove(weights[values.index(chosen)])
values.remove(chosen)
return chosen_values


def _fix_max_bound(max_bound: int | None, values: any | int) -> int:
max_options = values if isinstance(values, int) else len(values)
return min(max_bound, max_options) if max_bound else max_options
2 changes: 2 additions & 0 deletions tests/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ def test_variant_with_weights(self, input, weights):
), # https://github.com/adieyal/sd-dynamic-prompts/issues/223
("{!1-2$$cat|dog|bird}", 1, 2),
("{~1-2$$cat|dog|bird}", 1, 2),
("{2-$$cat|dog|bird}", 2, 3),
("{0-$$cat|dog|bird}", 0, 3),
],
)
def test_range(self, input, min_bound, max_bound):
Expand Down
47 changes: 47 additions & 0 deletions tests/samplers/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,53 @@ def test_nested_wildcard_with_range_and_literal(

assert [str(p) for p in prompts] == expected

@pytest.mark.parametrize(
("sampling_context", "key"),
[
# (lazy_fixture("random_sampling_context"), "shuffled_colours"), # TODO - fix this
(lazy_fixture("cyclical_sampling_context"), "wildcard_colours"),
(lazy_fixture("combinatorial_sampling_context"), "wildcard_colours"),
],
)
def test_nested_wildcard_with_range_min_without_max_and_literal(
self,
sampling_context: SamplingContext,
key: str,
data_lookups: dict[str, list[str]],
):
sampler = sampling_context.default_sampler

template = "{2-$$__colors*__|black}"
expected = data_lookups[key]

if isinstance(sampler, RandomSampler):
variant_choices = [[LiteralCommand("black")], [WildcardCommand("colors*")]]

with patch_random_sampler_wildcard_choice(expected):
with patch_random_sampler_variant_choices(variant_choices):
black = ["black"] * len(expected)
arr1 = zipstr(expected, black, sep=",")
arr2 = zipstr(black, expected, sep=",")
expected = interleave(arr1, arr2)

prompts = list(
sampling_context.sample_prompts(template, len(expected)),
)
else:
if isinstance(sampler, CyclicalSampler):
black = ["black"] * len(expected)
arr1 = zipstr(expected, black, sep=",")
arr2 = zipstr(black, expected, sep=",")
expected = interleave(arr1, arr2)
elif isinstance(sampler, CombinatorialSampler):
expected = [f"{e},black" for e in expected] + [
f"black,{e}" for e in expected
]

prompts = sampling_context.sample_prompts(template, len(expected))

assert [str(p) for p in prompts] == expected

@pytest.mark.parametrize(
("sampling_context"),
[
Expand Down

0 comments on commit 7e1e064

Please sign in to comment.