fix(x_attn_kwargs): only pass to pipeline if set
gadicc committed Aug 3, 2023
1 parent 5f46faa commit 3f1f980
Showing 1 changed file with 8 additions and 4 deletions.
api/app.py: 12 changes (8 additions & 4 deletions)
@@ -254,6 +254,7 @@ def sendStatus():
     )
     # downloaded_models.update({normalized_model_id: True})
     clearPipelines()
+    cross_attention_kwargs = None
     if model:
         model.to("cpu")  # Necessary to avoid a memory leak
     await send(
@@ -287,6 +288,7 @@ def sendStatus():
     if MODEL_ID == "ALL":
         if last_model_id != normalized_model_id:
             clearPipelines()
+            cross_attention_kwargs = None
             model = loadModel(normalized_model_id, send_opts=send_opts)
             last_model_id = normalized_model_id
     else:
@@ -447,8 +449,12 @@ def sendStatus():
     if mi_cross_attention_kwargs:
         model_inputs.pop("cross_attention_kwargs")
         if isinstance(mi_cross_attention_kwargs, str):
+            if not cross_attention_kwargs:
+                cross_attention_kwargs = {}
             cross_attention_kwargs.update(json.loads(mi_cross_attention_kwargs))
         elif type(mi_cross_attention_kwargs) == dict:
+            if not cross_attention_kwargs:
+                cross_attention_kwargs = {}
             cross_attention_kwargs.update(mi_cross_attention_kwargs)
         else:
             return {
@@ -459,6 +465,8 @@
             }

     print({"cross_attention_kwargs": cross_attention_kwargs})
+    if cross_attention_kwargs:
+        model_inputs.update({"cross_attention_kwargs": cross_attention_kwargs})

     # Parse out your arguments
     # prompt = model_inputs.get("prompt", None)
@@ -595,14 +603,11 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
         or isinstance(model, StableDiffusionXLImg2ImgPipeline)
         or isinstance(model, StableDiffusionXLInpaintPipeline)
     )
-    print("is_sdxl", is_sdxl)
-
     with torch.inference_mode():
         custom_pipeline_method = call_inputs.get("custom_pipeline_method", None)
         print(
             pipeline,
             {
-                "cross_attention_kwargs": cross_attention_kwargs,
                 "callback": callback,
                 "**model_inputs": model_inputs,
             },
@@ -616,7 +621,6 @@
             getattr(pipeline, custom_pipeline_method)
             if custom_pipeline_method
             else pipeline,
-            cross_attention_kwargs=cross_attention_kwargs,
             callback=callback,
             **model_inputs,
         )
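
In short, the fix makes the handler build its keyword arguments conditionally: cross_attention_kwargs is reset whenever pipelines are cleared, merged from the request if supplied, and only added to model_inputs when something was actually set, so the pipeline (or a custom pipeline method) is never handed cross_attention_kwargs=None. A minimal sketch of that pattern, using a hypothetical helper build_pipeline_kwargs that is not part of the repository:

    # Minimal sketch only; the real handler in api/app.py carries extra state
    # around model loading and error responses.
    import json

    def build_pipeline_kwargs(model_inputs, cross_attention_kwargs=None):
        """Merge request-supplied cross_attention_kwargs and keep the key in
        model_inputs only when something was actually set."""
        mi_kwargs = model_inputs.pop("cross_attention_kwargs", None)
        if mi_kwargs:
            if isinstance(mi_kwargs, str):  # JSON-encoded dict from the request
                mi_kwargs = json.loads(mi_kwargs)
            if not isinstance(mi_kwargs, dict):
                raise ValueError("cross_attention_kwargs must be a dict or a JSON string")
            cross_attention_kwargs = {**(cross_attention_kwargs or {}), **mi_kwargs}

        # Only forward the argument when it is set, so pipelines or custom
        # pipeline methods that don't accept it never see a stray None.
        if cross_attention_kwargs:
            model_inputs["cross_attention_kwargs"] = cross_attention_kwargs
        return model_inputs

The returned dict is then splatted into the pipeline call via **model_inputs, which is why the explicit cross_attention_kwargs=cross_attention_kwargs argument could be dropped from the call site.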
