remove attention mask for self-attention
a-r-r-o-w committed Oct 28, 2024
1 parent 2065adc · commit 9214f4a
Showing 1 changed file with 3 additions and 9 deletions.
src/diffusers/pipelines/allegro/pipeline_allegro.py (12 changes: 3 additions & 9 deletions)
@@ -843,6 +843,8 @@ def __call__(
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+        if prompt_embeds.ndim == 3:
+            prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d

         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -884,17 +886,9 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

-                if prompt_embeds.ndim == 3:
-                    prompt_embeds = prompt_embeds.unsqueeze(1)  # b l d -> b 1 l d
-
-                # prepare attention_mask.
-                # b c t h w -> b t h w
-                attention_mask = torch.ones_like(latent_model_input)[:, 0]
-
                 # predict noise model_output
                 noise_pred = self.transformer(
-                    latent_model_input,
-                    attention_mask=attention_mask,
+                    hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,

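Note: the mask removed here was built as torch.ones_like(latent_model_input)[:, 0], i.e. all ones, so it never masked anything in self-attention. The sketch below (a minimal plain-PyTorch illustration, not the Allegro transformer; the tensor shapes are assumptions) shows why an all-ones mask, once converted to an additive bias of zeros, leaves scaled dot-product attention unchanged.

import torch
import torch.nn.functional as F

# Illustrative shapes (assumptions, not taken from the pipeline).
batch, heads, seq, dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

# Analogous to torch.ones_like(latent_model_input)[:, 0]: a mask of all ones.
ones_mask = torch.ones(batch, seq)
# As an additive bias, an all-ones mask becomes all zeros ...
additive_bias = (1 - ones_mask)[:, None, None, :] * torch.finfo(q.dtype).min

# ... so attention with and without the mask gives the same result.
with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_bias)
without_mask = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(with_mask, without_mask, atol=1e-5))  # True, up to kernel numerical precision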