Skip to content

Commit

Permalink
redesign prompt tag parsing to handle recursion intelligently, esp. f…
Browse files Browse the repository at this point in the history
…or 'random' tag

for #130
  • Loading branch information
mcmonkey4eva committed Oct 19, 2023
1 parent 7a4e931 commit 0046b7b
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 77 deletions.
2 changes: 1 addition & 1 deletion src/Accounts/User.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ string buildPathPart(string part)
return data;
}
string path = Settings.OutPathBuilder.Format;
path = StringConversionHelper.QuickSimpleTagFiller(path, "[", "]", buildPathPart, maxDepth: 1);
path = StringConversionHelper.QuickSimpleTagFiller(path, "[", "]", buildPathPart, false);
return Utilities.StrictFilenameClean(path);
}
}
2 changes: 1 addition & 1 deletion src/StableSwarmUI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="FreneticLLC.FreneticUtilities" Version="1.0.16" />
<PackageReference Include="FreneticLLC.FreneticUtilities" Version="1.0.17" />
<PackageReference Include="LiteDB" Version="5.0.17" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="SixLabors.ImageSharp" Version="3.0.2" />
Expand Down
168 changes: 93 additions & 75 deletions src/Text2Image/T2IParamInput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,23 @@ public class PromptTagContext
public T2IParamInput Input;

public string Param;

public Func<string, string> EmbedFormatter;

public string[] Embeds, Loras;

public string Parse(string text)
{
return Input.ProcessPromptLike(text, this);
}
}

/// <summary>Mapping of prompt tag prefixes, to allow for registration of custom prompt tags.</summary>
public static Dictionary<string, Func<string, PromptTagContext, string>> PromptTagProcessors = new();

/// <summary>Mapping of prompt tag prefixes, to allow for registration of custom prompt tags - specifically post-processing like lora (which remove from prompt and get read elsewhere).</summary>
public static Dictionary<string, Func<string, PromptTagContext, string>> PromptTagPostProcessors = new();

static T2IParamInput()
{
PromptTagProcessors["random"] = (data, context) =>
Expand All @@ -34,11 +46,11 @@ static T2IParamInput()
{
return null;
}
return vals[context.Random.Next(vals.Length)];
return context.Parse(vals[context.Random.Next(vals.Length)]);
};
PromptTagProcessors["preset"] = (data, context) =>
{
T2IPreset preset = context.Input.SourceSession.User.GetPreset(data);
T2IPreset preset = context.Input.SourceSession.User.GetPreset(context.Parse(data));
if (preset is null)
{
return null;
Expand All @@ -50,7 +62,50 @@ static T2IParamInput()
}
return "";
};
// TODO: Wildcards
PromptTagProcessors["embed"] = (data, context) =>
{
data = context.Parse(data);
if (context.Embeds is not null)
{
string want = data.ToLowerFast().Replace('\\', '/');
string matched = context.Embeds.FirstOrDefault(e => e.ToLowerFast().StartsWith(want)) ?? context.Embeds.FirstOrDefault(e => e.ToLowerFast().Contains(want));
if (matched is not null)
{
data = matched;
}
}
return context.EmbedFormatter(data.Replace('/', Path.DirectorySeparatorChar));
};
PromptTagProcessors["embedding"] = PromptTagProcessors["embed"];
PromptTagPostProcessors["lora"] = (data, context) =>
{
data = context.Parse(data);
string lora = data.ToLowerFast().Replace('\\', '/');
int colonIndex = lora.IndexOf(':');
double strength = 1;
if (colonIndex != -1 && double.TryParse(lora[(colonIndex + 1)..], out strength))
{
lora = lora[..colonIndex];
}
string matched = context.Loras.FirstOrDefault(e => e.ToLowerFast().StartsWith(lora)) ?? context.Loras.FirstOrDefault(e => e.ToLowerFast().Contains(lora));
if (matched is not null)
{
List<string> loraList = context.Input.Get(T2IParamTypes.Loras);
List<string> weights = context.Input.Get(T2IParamTypes.LoraWeights);
if (loraList is null)
{
loraList = new();
weights = new();
}
loraList.Add(matched);
weights.Add(strength.ToString());
context.Input.Set(T2IParamTypes.Loras, loraList);
context.Input.Set(T2IParamTypes.LoraWeights, weights);
return "";
}
return null;
};
// TODO: Wildcards (random by user-editable listing files)
}

/// <summary>The raw values in this input. Do not use this directly, instead prefer:
Expand Down Expand Up @@ -190,89 +245,52 @@ public string ProcessPromptLike(T2IRegisteredParam<string> param, Func<string, s
string lowRef = val.ToLowerFast();
string[] embeds = lowRef.Contains("<embed") ? Program.T2IModelSets["Embedding"].ListModelsFor(SourceSession).Select(m => m.Name).ToArray() : null;
string[] loras = lowRef.Contains("<lora:") ? Program.T2IModelSets["LoRA"].ListModelsFor(SourceSession).Select(m => m.Name.ToLowerFast()).ToArray() : null;
PromptTagContext context = new() { Input = this, Random = rand, Param = param.Type.ID };
PromptTagContext context = new() { Input = this, Random = rand, Param = param.Type.ID, EmbedFormatter = embedFormatter, Embeds = embeds, Loras = loras };
string fixedVal = ProcessPromptLike(val, context);
if (fixedVal != val)
{
ExtraMeta[$"original_{param.Type.ID}"] = val;
}
return fixedVal;
}

/// <summary>Special utility to process prompt inputs before the request is executed (to parse wildcards, embeddings, etc).</summary>
public string ProcessPromptLike(string val, PromptTagContext context)
{
if (val is null)
{
return null;
}
string addBefore = "", addAfter = "";
string fixedVal = StringConversionHelper.QuickSimpleTagFiller(val, "<", ">", tag =>
void processSet(Dictionary<string, Func<string, PromptTagContext, string>> set)
{
(string prefix, string data) = tag.BeforeAndAfter(':');
if (string.IsNullOrWhiteSpace(data))
{
return $"<{tag}>";
}
switch (prefix.ToLowerFast())
val = StringConversionHelper.QuickSimpleTagFiller(val, "<", ">", tag =>
{
case "embed":
case "embedding":
{
if (embeds is not null)
{
string want = data.ToLowerFast().Replace('\\', '/');
string matched = embeds.FirstOrDefault(e => e.ToLowerFast().StartsWith(want)) ?? embeds.FirstOrDefault(e => e.ToLowerFast().Contains(want));
if (matched is not null)
{
data = matched;
}
}
return embedFormatter(data.Replace('/', Path.DirectorySeparatorChar));
}
case "lora":
(string prefix, string data) = tag.BeforeAndAfter(':');
if (!string.IsNullOrWhiteSpace(data) && set.TryGetValue(prefix, out Func<string, PromptTagContext, string> proc))
{
string result = proc(data, context);
if (result is not null)
{
string lora = data.ToLowerFast().Replace('\\', '/');
int colonIndex = lora.IndexOf(':');
double strength = 1;
if (colonIndex != -1 && double.TryParse(lora[(colonIndex + 1)..], out strength))
if (result.StartsWithNull()) // Special case for preset tag modifying the current value
{
lora = lora[..colonIndex];
}
string matched = loras.FirstOrDefault(e => e.ToLowerFast().StartsWith(lora)) ?? loras.FirstOrDefault(e => e.ToLowerFast().Contains(lora));
if (matched is not null)
{
List<string> loraList = Get(T2IParamTypes.Loras);
List<string> weights = Get(T2IParamTypes.LoraWeights);
if (loraList is null)
result = result[1..];
if (result.Contains("{value}"))
{
loraList = new();
weights = new();
addBefore += result.Before("{value}");
}
loraList.Add(matched);
weights.Add(strength.ToString());
Set(T2IParamTypes.Loras, loraList);
Set(T2IParamTypes.LoraWeights, weights);
addAfter += result.After("{value}");
return "";
}
else
{
return $"<{tag}>";
}
}
default:
if (PromptTagProcessors.TryGetValue(prefix, out Func<string, PromptTagContext, string> proc))
{
string result = proc(data, context);
if (result is not null)
{
if (result.StartsWithNull()) // Special case for preset tag modifying the current value
{
result = result[1..];
if (result.Contains("{value}"))
{
addBefore += result.Before("{value}");
}
addAfter += result.After("{value}");
return "";
}
return result;
}
return result;
}
return $"<{tag}>";
}
});
fixedVal = addBefore + fixedVal + addAfter;
if (fixedVal != val)
{
ExtraMeta[$"original_{param.Type.ID}"] = val;
}
return $"<{tag}>";
}, false, 0);
}
return fixedVal;
processSet(PromptTagProcessors);
processSet(PromptTagPostProcessors);
return addBefore + val + addAfter;
}

/// <summary>Gets the raw value of the parameter, if it is present, or null if not.</summary>
Expand Down

0 comments on commit 0046b7b

Please sign in to comment.