[Community Pipeline] MagicMix #1839
Merged

Commits (showing changes from 5 of 7 commits):
- `87732e2` initial (daspartho)
- `d930ced` type hints (daspartho)
- `6847d59` update scheduler type hint (daspartho)
- `a3095f1` add to README (daspartho)
- `cbf4ca5` add example generation to README (daspartho)
- `28e1a3c` v -> mix_factor (daspartho)
- `573dec8` load scheduler from pretrained (daspartho)
README changes:
@@ -25,6 +25,7 @@ If a community pipeline doesn't work as expected, please open an issue and ping the author.

| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
@@ -815,6 +816,50 @@ plt.title('Stable Diffusion v1.4')
plt.axis('off')
plt.show()

As a result, you can look at a grid of all 4 generated images shown together, which captures the differences that the progression of training makes across the 4 checkpoints.
### Magic Mix

Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt to create a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process.
There are 3 parameters for the method:
- `v`: the interpolation constant used in the layout generation phase. The greater the value of `v`, the greater the influence of the prompt on the layout generation process.
- `kmax` and `kmin`: these determine the range for the layout and content generation processes. A higher value of `kmax` results in more of the original image's layout information being lost, and a higher value of `kmin` results in more steps for the content generation process; the sketch after this list shows how they map to concrete denoising steps.
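
To make that mapping concrete, here is a minimal sketch of the arithmetic used in the pipeline code below (the specific numbers are illustrative only):

```python
steps = 50  # total denoising steps
kmin, kmax = 0.3, 0.5

# magic_mix.py turns the fractions into scheduler step indices like this:
tmin = steps - int(kmin * steps)  # 35
tmax = steps - int(kmax * steps)  # 25

# Steps tmax+1 .. tmin-1 form the layout generation phase, where the latents
# are interpolated with a freshly noised copy of the input image.
# Steps tmin .. steps-1 form the content generation phase (plain guided denoising).
print(tmax, tmin)  # 25 35
```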
Here is an example usage:
```python
from diffusers import DiffusionPipeline, DDIMScheduler
from PIL import Image

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
    scheduler=DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False),
).to('cuda')

img = Image.open('phone.jpg')
mix_img = pipe(
    img,
    prompt='bed',
    kmin=0.3,
    kmax=0.5,
    v=0.5,
)
mix_img.save('phone_bed_mix.jpg')
```

`mix_img` is a PIL image that can be saved locally or displayed directly, e.g. in a Google Colab notebook. The generated image is a mix of the layout semantics of the given image and the content semantics of the prompt.

Review comment (on the `scheduler` argument): The scheduler can be loaded using `DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")`.
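
A minimal sketch of that suggestion applied to the example above (presumably what commit `573dec8`, "load scheduler from pretrained", addresses):

```python
from diffusers import DiffusionPipeline, DDIMScheduler

# Load the DDIM scheduler config from the model repo instead of
# constructing the scheduler by hand.
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
    scheduler=scheduler,
).to("cuda")
```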
For example, the above script generates the following image:

`phone.jpg`

![206903102-34e79b9f-9ed2-4fac-bb38-82871343c655](https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg)

`phone_bed_mix.jpg`

![206903104-913a671d-ef53-4ae4-919d-64c3059c8f67](https://user-images.githubusercontent.com/59410571/209578602-70f323fa-05b7-4dd6-b055-e40683e37914.jpg)
For more example generations, check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb).
The new community pipeline file, `magic_mix.py`:

@@ -0,0 +1,152 @@

```python
from typing import Union

import torch

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from PIL import Image
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


class MagicMixPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
    ):
        super().__init__()

        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    # convert PIL image to latents
    def encode(self, img):
        with torch.no_grad():
            latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1)
            latent = 0.18215 * latent.latent_dist.sample()
        return latent

    # convert latents to PIL image
    def decode(self, latent):
        latent = (1 / 0.18215) * latent
        with torch.no_grad():
            img = self.vae.decode(latent).sample
        img = (img / 2 + 0.5).clamp(0, 1)
        img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
        img = (img * 255).round().astype("uint8")
        return Image.fromarray(img[0])

    # convert prompt into text embeddings, also unconditional embeddings
    def prep_text(self, prompt):
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0]

        uncond_input = self.tokenizer(
            "",
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        return torch.cat([uncond_embedding, text_embedding])

    def __call__(
        self,
        img: Image.Image,
        prompt: str,
        kmin: float = 0.3,
        kmax: float = 0.6,
        v: float = 0.5,
        seed: int = 42,
        steps: int = 50,
        guidance_scale: float = 7.5,
    ) -> Image.Image:
        tmin = steps - int(kmin * steps)
        tmax = steps - int(kmax * steps)

        text_embeddings = self.prep_text(prompt)

        self.scheduler.set_timesteps(steps)

        width, height = img.size
        encoded = self.encode(img)

        torch.manual_seed(seed)
        noise = torch.randn(
            (1, self.unet.in_channels, height // 8, width // 8),
        ).to(self.device)

        latents = self.scheduler.add_noise(
            encoded,
            noise,
            timesteps=self.scheduler.timesteps[tmax],
        )

        input = torch.cat([latents] * 2)

        input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax])

        with torch.no_grad():
            pred = self.unet(
                input,
                self.scheduler.timesteps[tmax],
                encoder_hidden_states=text_embeddings,
            ).sample

        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

        latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            if i > tmax:
                if i < tmin:  # layout generation phase
                    orig_latents = self.scheduler.add_noise(
                        encoded,
                        noise,
                        timesteps=t,
                    )

                    # interpolate between the conditionally generated latents and the
                    # noised original image to preserve layout semantics
                    input = v * latents + (1 - v) * orig_latents
                    input = torch.cat([input] * 2)

                else:  # content generation phase
                    input = torch.cat([latents] * 2)

                input = self.scheduler.scale_model_input(input, t)

                with torch.no_grad():
                    pred = self.unet(
                        input,
                        t,
                        encoder_hidden_states=text_embeddings,
                    ).sample

                pred_uncond, pred_text = pred.chunk(2)
                pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

                latents = self.scheduler.step(pred, t, latents).prev_sample

        return self.decode(latents)
```

Review comment (on the `v` argument): could we use a more descriptive name for this argument? One-letter variables aren't informative. (Commit `28e1a3c`, `v -> mix_factor`, renames it.)
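
As a hedged illustration, assuming commit `28e1a3c` does nothing beyond the rename its message describes, the call from the README example would become:

```python
# Hypothetical call after the rename in commit 28e1a3c ("v -> mix_factor");
# mix_factor plays exactly the role v plays in the code above.
mix_img = pipe(
    img,
    prompt="bed",
    kmin=0.3,
    kmax=0.5,
    mix_factor=0.5,  # interpolation constant for the layout generation phase
)
```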
Review comment: Also, maybe load the pipeline in fp16, by passing the `torch_dtype` argument, to make inference faster.
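
A minimal sketch of that suggestion, assuming the same checkpoint and custom pipeline as in the README example:

```python
import torch
from diffusers import DiffusionPipeline

# Load the weights in half precision to make inference faster and lighter on memory.
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
    torch_dtype=torch.float16,
).to("cuda")
```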