diff --git a/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py b/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py
index bd87ccf..b75ca5e 100644
--- a/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py
+++ b/classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py
@@ -842,5 +842,16 @@ def forward(
         if cond_drop_prob > 0.:
             prob_keep_mask = prob_mask_like((batch, 1), 1. - cond_drop_prob, device = self.device)
             mask = mask & prob_keep_mask
+
+            b, length, dim = text_embeds.shape
+            masked_lengths = mask.sum(dim=1).tolist()
+
+            max_length = max(masked_lengths)
+            text_embeds_dropped = torch.full((b, max_length, dim), self.text_embed_pad_value, dtype=text_embeds.dtype, device=text_embeds.device)
+            for i in range(b):
+                text_embeds_dropped[i, :masked_lengths[i]] = text_embeds[i, mask[i]]
+
+            mask = (text_embeds_dropped != self.text_embed_pad_value).any(dim = -1)
+            text_embeds = text_embeds_dropped
 
         return tuple(self.conditioners), TextCondReturn(text_embeds, mask)
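
For readers skimming the hunk: it left-packs the embeddings that survive the keep-mask, trims the batch to the longest surviving sequence, pads the remainder with `self.text_embed_pad_value`, and then rebuilds the mask from the pad value (the same convention `forward` already uses when deriving a mask from `text_embeds`). Below is a minimal standalone sketch of that operation; `pack_kept_embeds` and its `pad_value` default are illustrative names, not part of the library's API.

```python
# A standalone sketch of the patch's left-pack-and-repad step, assuming a
# boolean keep-mask of shape (batch, length) and a scalar pad value.
import torch

def pack_kept_embeds(text_embeds, mask, pad_value = 0.):
    # text_embeds: (batch, length, dim); mask: (batch, length), True = keep
    b, _, dim = text_embeds.shape
    kept_lengths = mask.sum(dim = 1).tolist()
    max_length = max(kept_lengths)

    # allocate an output trimmed to the longest kept sequence,
    # pre-filled with the pad value
    packed = torch.full(
        (b, max_length, dim), pad_value,
        dtype = text_embeds.dtype, device = text_embeds.device
    )

    # left-pack each sequence's kept token embeddings
    for i in range(b):
        packed[i, :kept_lengths[i]] = text_embeds[i, mask[i]]

    # recover the mask from the pad value, mirroring the patch
    new_mask = (packed != pad_value).any(dim = -1)
    return packed, new_mask

# usage example
embeds = torch.randn(2, 4, 8)
mask = torch.tensor([[True, False, True, True],
                     [True, True, False, False]])
packed, new_mask = pack_kept_embeds(embeds, mask)
print(packed.shape)   # torch.Size([2, 3, 8]) -- trimmed to the longest kept sequence
print(new_mask)       # True only where a real (non-pad) embedding remains
```

Two caveats carried over from the patch itself: rebuilding the mask by comparing against the pad value will also mask out any real embedding row that happens to equal the pad value everywhere, and `max(masked_lengths)` is zero when every token in the batch is dropped, yielding an empty second dimension.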