Skip to content

Commit

Permalink
Reformat code and use decorators (#244)
Browse files Browse the repository at this point in the history
Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
  • Loading branch information
Warvito authored Feb 11, 2023
1 parent eaaefdc commit 57bf5c4
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions generative/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __call__(

return prediction

@torch.no_grad()
def sample(
self,
input_noise: torch.Tensor,
Expand Down Expand Up @@ -89,10 +90,9 @@ def sample(
intermediates = []
for t in progress_bar:
# 1. predict noise model_output
with torch.no_grad():
model_output = diffusion_model(
image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
)
model_output = diffusion_model(
image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
)

# 2. compute previous image: x_t -> x_t-1
image, _ = scheduler.step(model_output, t, image)
Expand Down Expand Up @@ -310,6 +310,7 @@ def __call__(

return prediction

@torch.no_grad()
def sample(
self,
input_noise: torch.Tensor,
Expand Down Expand Up @@ -347,16 +348,12 @@ def sample(
else:
latent = outputs

with torch.no_grad():
image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)
image = autoencoder_model.decode_stage_2_outputs(latent / self.scale_factor)

if save_intermediates:
intermediates = []
for latent_intermediate in latent_intermediates:
with torch.no_grad():
intermediates.append(
autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor)
)
intermediates.append(autoencoder_model.decode_stage_2_outputs(latent_intermediate / self.scale_factor))
return image, intermediates

else:
Expand Down

0 comments on commit 57bf5c4

Please sign in to comment.