[Community Pipeline] MagicMix #1839

Merged 7 commits on Dec 28, 2022
47 changes: 46 additions & 1 deletion examples/community/README.md
@@ -25,6 +25,7 @@ If a community doesn't work as expected, please open an issue and ping the author
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |



@@ -815,6 +816,50 @@ plt.title('Stable Diffusion v1.4')
plt.axis('off')

plt.show()
```

As a result, you can view a grid of all four generated images together, which captures the differences that arise from the training progression across the four checkpoints.

### Magic Mix

Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process.

There are three parameters for the method:
- `mix_factor`: the interpolation constant used in the layout generation phase. The greater the value of `mix_factor`, the greater the influence of the prompt on the generated layout.
- `kmin` and `kmax`: fractions of the total denoising steps that bound the layout and content generation phases. A higher `kmax` discards more of the original image's layout information, while a higher `kmin` allocates more steps to the content generation phase (see the sketch after this list for how they map to timestep indices).
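
To make the mapping concrete, here is a minimal sketch (mirroring the arithmetic in `magic_mix.py`) of how `kmin` and `kmax` translate into denoising-step indices with the default `steps=50`:

```python
steps = 50  # total denoising steps (pipeline default)
kmin, kmax = 0.3, 0.5

# same arithmetic as in MagicMixPipeline.__call__
tmax = steps - int(kmax * steps)  # 25: the input image is noised to this timestep index
tmin = steps - int(kmin * steps)  # 35: the layout phase ends and the content phase begins

print(f"layout phase:  steps {tmax + 1}-{tmin - 1}")  # steps 26-34
print(f"content phase: steps {tmin}-{steps - 1}")  # steps 35-49
```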

Here is an example usage:

```python
from diffusers import DiffusionPipeline, DDIMScheduler
from PIL import Image

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
    # Reviewer note: the pipeline can also be loaded in fp16 by passing the
    # torch_dtype argument, to make inference faster (see the sketch below).
    scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
).to('cuda')

img = Image.open('phone.jpg')
mix_img = pipe(
    img,
    prompt='bed',
    kmin=0.3,
    kmax=0.5,
    mix_factor=0.5,
)
mix_img.save('phone_bed_mix.jpg')
```
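
Following the reviewer's suggestion above, the pipeline can also be loaded in half precision by passing the `torch_dtype` argument to `from_pretrained`, which speeds up inference on GPU. A minimal sketch of that variant (note that the pipeline's `encode` method does not explicitly cast the input image tensor to the model dtype, so fp16 may require a small extra cast):

```python
import torch

from diffusers import DDIMScheduler, DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
    scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
    torch_dtype=torch.float16,  # load the model weights in fp16
).to('cuda')
```
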
`mix_img` is a PIL image that can be saved locally or displayed directly in a Google Colab notebook. The generated image is a mix of the layout semantics of the given image and the content semantics of the prompt.

For example, the above script turns the input image into the mixed output shown below:

`phone.jpg`

![206903102-34e79b9f-9ed2-4fac-bb38-82871343c655](https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg)

`phone_bed_mix.jpg`

![206903104-913a671d-ef53-4ae4-919d-64c3059c8f67](https://user-images.githubusercontent.com/59410571/209578602-70f323fa-05b7-4dd6-b055-e40683e37914.jpg)

For more example generations, check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb).
152 changes: 152 additions & 0 deletions examples/community/magic_mix.py
@@ -0,0 +1,152 @@
from typing import Union

import torch

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from PIL import Image
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


class MagicMixPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
    ):
        super().__init__()

        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    # convert PIL image to latents
    def encode(self, img):
        with torch.no_grad():
            latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1)
            latent = 0.18215 * latent.latent_dist.sample()
        return latent

    # convert latents to PIL image
    def decode(self, latent):
        latent = (1 / 0.18215) * latent
        with torch.no_grad():
            img = self.vae.decode(latent).sample
        img = (img / 2 + 0.5).clamp(0, 1)
        img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
        img = (img * 255).round().astype("uint8")
        return Image.fromarray(img[0])

    # convert prompt into text embeddings, also unconditional embeddings
    def prep_text(self, prompt):
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0]

        uncond_input = self.tokenizer(
            "",
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        return torch.cat([uncond_embedding, text_embedding])

    def __call__(
        self,
        img: Image.Image,
        prompt: str,
        kmin: float = 0.3,
        kmax: float = 0.6,
        mix_factor: float = 0.5,
        seed: int = 42,
        steps: int = 50,
        guidance_scale: float = 7.5,
    ) -> Image.Image:
        # map the kmin/kmax fractions to the timestep indices that bound the layout phase
        tmin = steps - int(kmin * steps)
        tmax = steps - int(kmax * steps)

        text_embeddings = self.prep_text(prompt)

        self.scheduler.set_timesteps(steps)

        width, height = img.size
        encoded = self.encode(img)

        torch.manual_seed(seed)
        noise = torch.randn(
            (1, self.unet.in_channels, height // 8, width // 8),
        ).to(self.device)

        # noise the encoded image to the starting timestep of the layout phase
        latents = self.scheduler.add_noise(
            encoded,
            noise,
            timesteps=self.scheduler.timesteps[tmax],
        )

        input = torch.cat([latents] * 2)

        input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax])

        with torch.no_grad():
            pred = self.unet(
                input,
                self.scheduler.timesteps[tmax],
                encoder_hidden_states=text_embeddings,
            ).sample

        # classifier-free guidance
        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

        latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            if i > tmax:
                if i < tmin:  # layout generation phase
                    orig_latents = self.scheduler.add_noise(
                        encoded,
                        noise,
                        timesteps=t,
                    )

                    input = (mix_factor * latents) + (
                        1 - mix_factor
                    ) * orig_latents  # interpolating between layout noise and conditionally generated noise to preserve layout semantics
                    input = torch.cat([input] * 2)

                else:  # content generation phase
                    input = torch.cat([latents] * 2)

                input = self.scheduler.scale_model_input(input, t)

                with torch.no_grad():
                    pred = self.unet(
                        input,
                        t,
                        encoder_hidden_states=text_embeddings,
                    ).sample

                pred_uncond, pred_text = pred.chunk(2)
                pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

                latents = self.scheduler.step(pred, t, latents).prev_sample

        return self.decode(latents)