-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
150 - Diff-scm #306
150 - Diff-scm #306
Changes from 13 commits
6a3efe0
4d9c30d
a750e53
cc7b704
2a0bed9
20535f0
04dff72
ff75429
f496f70
b8a2a48
c25950e
29c0699
f50f545
472b1f7
5940377
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1889,4 +1889,4 @@ def forward( | |
# 7. output block | ||
h = self.out(h) | ||
|
||
return h | ||
return h | ||
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -225,7 +225,83 @@ def step( | |||||||||
|
||||||||||
return pred_prev_sample, pred_original_sample | ||||||||||
|
||||||||||
def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: | ||||||||||
def reversed_step( | ||||||||||
self, | ||||||||||
model_output: torch.Tensor, | ||||||||||
timestep: int, | ||||||||||
sample: torch.Tensor, | ||||||||||
eta: float = 0.0, | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove eta since it is not used in reversed_step |
||||||||||
) -> tuple[torch.Tensor, torch.Tensor]: | ||||||||||
""" | ||||||||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion | ||||||||||
process from the learned model outputs (most often the predicted noise). | ||||||||||
|
||||||||||
Args: | ||||||||||
model_output: direct output from learned diffusion model. | ||||||||||
timestep: current discrete timestep in the diffusion chain. | ||||||||||
sample: current instance of sample being created by diffusion process. | ||||||||||
eta: weight of noise for added noise in diffusion step. | ||||||||||
predict_epsilon: flag to use when model predicts the samples directly instead of the noise, epsilon. | ||||||||||
generator: random number generator. | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Arguments not declared |
||||||||||
|
||||||||||
Returns: | ||||||||||
pred_prev_sample: Predicted previous sample | ||||||||||
pred_original_sample: Predicted original sample | ||||||||||
""" | ||||||||||
# See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf | ||||||||||
|
||||||||||
# Notation (<variable name> -> <name in paper> | ||||||||||
# - model_output -> e_theta(x_t, t) | ||||||||||
# - pred_original_sample -> f_theta(x_t, t) or x_0 | ||||||||||
# - std_dev_t -> sigma_t | ||||||||||
# - eta -> η | ||||||||||
# - pred_sample_direction -> "direction pointing to x_t" | ||||||||||
# - pred_post_sample -> "x_t+1" | ||||||||||
|
||||||||||
assert eta == 0, "eta must be 0 for reversed_step" | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove eta since it is not used in the reversed_step |
||||||||||
|
||||||||||
# 1. get previous step value (=t-1) | ||||||||||
prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps # t+1 | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
# 2. compute alphas, betas | ||||||||||
alpha_prod_t = self.alphas_cumprod[timestep] | ||||||||||
alpha_prod_t_prev = ( | ||||||||||
self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod | ||||||||||
) # alpha at timestep t+1 | ||||||||||
|
||||||||||
beta_prod_t = 1 - alpha_prod_t | ||||||||||
|
||||||||||
# 3. compute predicted original sample from predicted noise also called | ||||||||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf | ||||||||||
if self.prediction_type == "epsilon": | ||||||||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) | ||||||||||
elif self.prediction_type == "sample": | ||||||||||
pred_original_sample = model_output | ||||||||||
elif self.prediction_type == "v_prediction": | ||||||||||
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output | ||||||||||
# predict V | ||||||||||
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adopt new variable names used for DDIM.step (with |
||||||||||
|
||||||||||
# 4. Clip "predicted x_0" | ||||||||||
if self.clip_sample: | ||||||||||
pred_original_sample = torch.clamp(pred_original_sample, -1, 1) | ||||||||||
|
||||||||||
|
||||||||||
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix numbering "5. " |
||||||||||
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output | ||||||||||
|
||||||||||
# 7. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix numbering "6. " |
||||||||||
pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction | ||||||||||
|
||||||||||
return pred_post_sample, pred_original_sample | ||||||||||
|
||||||||||
def add_noise( | ||||||||||
self, | ||||||||||
original_samples: torch.Tensor, | ||||||||||
noise: torch.Tensor, | ||||||||||
timesteps: torch.Tensor, | ||||||||||
) -> torch.Tensor: | ||||||||||
|
||||||||||
""" | ||||||||||
Add noise to the original samples. | ||||||||||
|
||||||||||
|
@@ -270,4 +346,4 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor | |||||||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) | ||||||||||
|
||||||||||
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample | ||||||||||
return velocity | ||||||||||
return velocity |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please run
./runtests.sh --autofix
to fix the formatting issues.