Skip to content

Commit

Permalink
fix: crashing when lora trigger is used (#120)
Browse files Browse the repository at this point in the history
feat: adding generation metadata to the save
fix: avoid clobbering batched images on save
  • Loading branch information
db0 authored Jan 16, 2024
1 parent 6eaebca commit 4b38e0c
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 7 deletions.
3 changes: 2 additions & 1 deletion StableHordeClient.gd
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ func _get_test_images(n = 10) -> Array:
"none",
img,
'Test Image ID',
"Test Request ID")
"Test Request ID",
[])
new_texture.create_from_image(img)
test_array.append(new_texture)
return(test_array)
Expand Down
18 changes: 16 additions & 2 deletions addons/stable_horde_client/AIImageTexture.gd
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
class_name AIImageTexture
extends ImageTexture

const FILENAME_TEMPLATE := "{timestamp}_{gen_seed}"
const FILENAME_TEMPLATE := "{timestamp}_{gen_seed}{batch_id}"
const DIRECTORY_TEMPLATE := "{sampler_name}_{steps}_{prompt}"

# The prompt which generated this image
Expand Down Expand Up @@ -32,6 +32,7 @@ var timestamp: float
var image_horde_id: String
var control_type: String
var request_id: String
var gen_metadata: Array

func _init(
_prompt: String,
Expand All @@ -44,7 +45,9 @@ func _init(
_control_type: String,
_image: Image,
_image_horde_id: String,
_request_id: String) -> void:
_request_id: String,
_gen_metadata: Array
) -> void:
._init()
prompt = _prompt
attributes = _imgen_params.duplicate(true)
Expand All @@ -68,6 +71,8 @@ func _init(
timestamp = _timestamp
request_id = _request_id
attributes['request_id'] = _request_id
gen_metadata = _gen_metadata
attributes['gen_metadata'] = _gen_metadata

# This can be used to provide metadata for the source image in img2img requests
func set_source_image_path(image_path: String) -> void:
Expand All @@ -78,7 +83,10 @@ func get_filename() -> String:
var fmt := {
"timestamp": timestamp,
"gen_seed": gen_seed,
"batch_id": '',
}
if _get_batch_id() != '':
fmt["batch_id"] = "_" + _get_batch_id()
var filename = sanitize_filename(FILENAME_TEMPLATE.format(fmt)).substr(0,100)
return(filename)

Expand Down Expand Up @@ -155,3 +163,9 @@ static func sanitize_filename(filename: String) -> String:
for c in replace_chars:
filename = filename.replace(c,'_')
return(filename)

func _get_batch_id() -> String:
for meta in gen_metadata:
if meta["type"] == "batch_index":
return meta["ref"]
return ""
4 changes: 3 additions & 1 deletion addons/stable_horde_client/stable_horde_client.gd
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ func prepare_aitexture(imgbuffer: PoolByteArray, img_dict: Dictionary, timestamp
control_type,
image,
img_dict["id"],
async_request_id)
async_request_id,
img_dict.get("gen_metadata", [])
)
texture.create_from_image(image)
latest_image_textures.append(texture)
# Avoid keeping all images in RAM. Until I find a reason for it.
Expand Down
3 changes: 2 additions & 1 deletion src/Lora/Lora.gd
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ func update_selected_loras_label() -> void:

func _on_lora_trigger_pressed(index: int) -> void:
var version_id: String = selected_loras_list[index]["id"]
var lora_reference := lora_reference_node.get_lora_info(version_id)
var is_version: bool = selected_loras_list[index]["is_version"]
var lora_reference := lora_reference_node.get_lora_info(version_id, is_version)
var selected_triggers: Array = []
if lora_reference['versions'][version_id]['triggers'].size() == 1:
selected_triggers = [lora_reference['versions'][version_id]['triggers'][0]]
Expand Down
2 changes: 0 additions & 2 deletions src/ParamBus.gd
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,8 @@ func _on_listnode_changed(_thing_list: Array, thing_node: Node) -> void:
emit_signal("params_changed")

func is_lcm_payload() -> bool:
print_debug('www')
if loras_node.has_lcm_loras():
return true
print_debug(sampler_name_node.get_item_text(sampler_name_node.selected))
if sampler_name_node.get_item_text(sampler_name_node.selected) == 'lcm':
return true
return false
Expand Down

0 comments on commit 4b38e0c

Please sign in to comment.