Added support for Stable Diffusion 2.1
Hila Chefer committed Mar 1, 2023
1 parent 5f10e15 commit 15c30b1
Showing 4 changed files with 494 additions and 125 deletions.
2 changes: 2 additions & 0 deletions config.py
@@ -7,6 +7,8 @@
class RunConfig:
    # Guiding text prompt
    prompt: str
+   # Whether to use Stable Diffusion v2.1
+   sd_2_1: bool = False
    # Which token indices to alter with attend-and-excite
    token_indices: List[int] = None
    # Which random seeds to use when generating
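The new field makes the model choice part of the run configuration. A minimal sketch of enabling it, assuming RunConfig is the dataclass defined in config.py and that its remaining fields keep their defaults (the prompt and token indices below are illustrative, not taken from this commit):

from config import RunConfig

# Illustrative values; token_indices marks which prompt tokens
# Attend-and-Excite should strengthen.
config = RunConfig(prompt="a cat and a frog",
                   token_indices=[2, 5],
                   sd_2_1=True)  # defaults to False, i.e. Stable Diffusion v1.4
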
568 changes: 456 additions & 112 deletions notebooks/generate_images.ipynb

Large diffs are not rendered by default.

39 changes: 28 additions & 11 deletions pipeline_attend_and_excite.py
@@ -188,14 +188,21 @@ def _encode_prompt(

        return text_inputs, prompt_embeds

-   @staticmethod
-   def _compute_max_attention_per_index(attention_maps: torch.Tensor,
+   def _compute_max_attention_per_index(self,
+                                        attention_maps: torch.Tensor,
                                         indices_to_alter: List[int],
                                         smooth_attentions: bool = False,
                                         sigma: float = 0.5,
-                                        kernel_size: int = 3) -> List[torch.Tensor]:
+                                        kernel_size: int = 3,
+                                        normalize_eot: bool = False) -> List[torch.Tensor]:
        """ Computes the maximum attention value for each of the tokens we wish to alter. """
-       attention_for_text = attention_maps[:, :, 1:-1]
+       last_idx = -1
+       if normalize_eot:
+           prompt = self.prompt
+           if isinstance(self.prompt, list):
+               prompt = self.prompt[0]
+           last_idx = len(self.tokenizer(prompt)['input_ids']) - 1
+       attention_for_text = attention_maps[:, :, 1:last_idx]
        attention_for_text *= 100
        attention_for_text = torch.nn.functional.softmax(attention_for_text, dim=-1)

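The new normalize_eot path restricts the attention normalization to the prompt's real tokens: the prompt is re-tokenized to locate its end-of-text (EOT) token, and the maps are sliced up to that index instead of the fixed 1:-1 slice. A standalone sketch of the same index computation, assuming the pipeline's tokenizer is the CLIPTokenizer shipped with the SD 2.1 checkpoint (the prompt is illustrative):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer")
prompt = "a cat and a frog"

# Without padding, input_ids is [BOS, tok_1, ..., tok_n, EOT],
# so the EOT token sits at the last position.
last_idx = len(tokenizer(prompt)["input_ids"]) - 1

# Mirrors the pipeline: drop BOS (index 0) and everything from EOT onward
# before the softmax over the token attention maps.
# attention_for_text = attention_maps[:, :, 1:last_idx]
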
@@ -218,7 +225,8 @@ def _aggregate_and_get_max_attention_per_token(self, attention_store: AttentionS
                                                   attention_res: int = 16,
                                                   smooth_attentions: bool = False,
                                                   sigma: float = 0.5,
-                                                  kernel_size: int = 3):
+                                                  kernel_size: int = 3,
+                                                  normalize_eot: bool = False):
        """ Aggregates the attention for each token and computes the max activation value for each token to alter. """
        attention_maps = aggregate_attention(
            attention_store=attention_store,
@@ -231,7 +239,8 @@ def _aggregate_and_get_max_attention_per_token(self, attention_store: AttentionS
            indices_to_alter=indices_to_alter,
            smooth_attentions=smooth_attentions,
            sigma=sigma,
-           kernel_size=kernel_size)
+           kernel_size=kernel_size,
+           normalize_eot=normalize_eot)
        return max_attention_per_index

    @staticmethod
@@ -265,7 +274,8 @@ def _perform_iterative_refinement_step(self,
                                           smooth_attentions: bool = True,
                                           sigma: float = 0.5,
                                           kernel_size: int = 3,
-                                          max_refinement_steps: int = 20):
+                                          max_refinement_steps: int = 20,
+                                          normalize_eot: bool = False):
        """
        Performs the iterative latent refinement introduced in the paper. Here, we continuously update the latent
        code according to our loss objective until the given threshold is reached for all tokens.
@@ -286,7 +296,9 @@ def _perform_iterative_refinement_step(self,
                attention_res=attention_res,
                smooth_attentions=smooth_attentions,
                sigma=sigma,
-               kernel_size=kernel_size)
+               kernel_size=kernel_size,
+               normalize_eot=normalize_eot
+               )

            loss, losses = self._compute_loss(max_attention_per_index, return_losses=True)

@@ -324,7 +336,8 @@ def _perform_iterative_refinement_step(self,
            attention_res=attention_res,
            smooth_attentions=smooth_attentions,
            sigma=sigma,
-           kernel_size=kernel_size)
+           kernel_size=kernel_size,
+           normalize_eot=normalize_eot)
        loss, losses = self._compute_loss(max_attention_per_index, return_losses=True)
        print(f"\t Finished with loss of: {loss}")
        return loss, latents, max_attention_per_index
@@ -360,6 +373,7 @@ def __call__(
            smooth_attentions: bool = True,
            sigma: float = 0.5,
            kernel_size: int = 3,
+           sd_2_1: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
@@ -438,6 +452,7 @@ def __call__(
        )

        # 2. Define call parameters
+       self.prompt = prompt
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
@@ -508,7 +523,8 @@ def __call__(
                        attention_res=attention_res,
                        smooth_attentions=smooth_attentions,
                        sigma=sigma,
-                       kernel_size=kernel_size)
+                       kernel_size=kernel_size,
+                       normalize_eot=sd_2_1)

                    if not run_standard_sd:

@@ -531,7 +547,8 @@ def __call__(
                            attention_res=attention_res,
                            smooth_attentions=smooth_attentions,
                            sigma=sigma,
-                           kernel_size=kernel_size)
+                           kernel_size=kernel_size,
+                           normalize_eot=sd_2_1)

                    # Perform gradient update
                    if i < max_iter_to_alter:
10 changes: 8 additions & 2 deletions run.py
@@ -16,7 +16,12 @@

def load_model(config: RunConfig):
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
-   stable = AttendAndExcitePipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
+
+   if config.sd_2_1:
+       stable_diffusion_version = "stabilityai/stable-diffusion-2-1-base"
+   else:
+       stable_diffusion_version = "CompVis/stable-diffusion-v1-4"
+   stable = AttendAndExcitePipeline.from_pretrained(stable_diffusion_version).to(device)
    return stable


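With this change the checkpoint is selected from the config rather than hard-coded. A minimal usage sketch, assuming run.py can be imported as a module, the remaining RunConfig fields have defaults, and the usual diffusers/Hugging Face download setup is in place (the prompt is illustrative):

from config import RunConfig
from run import load_model

# sd_2_1=True -> "stabilityai/stable-diffusion-2-1-base"; the default (False)
# keeps the original "CompVis/stable-diffusion-v1-4" behaviour.
stable = load_model(RunConfig(prompt="a cat and a frog", sd_2_1=True))
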
@@ -54,7 +59,8 @@ def run_on_prompt(prompt: List[str],
                          scale_range=config.scale_range,
                          smooth_attentions=config.smooth_attentions,
                          sigma=config.sigma,
-                         kernel_size=config.kernel_size)
+                         kernel_size=config.kernel_size,
+                         sd_2_1=config.sd_2_1)
    image = outputs.images[0]
    return image

