Add LMSDiscreteSchedulerTest #467
Conversation
Hi @sidthekidder thanks a lot for diving into these tests! The LMS scheduler differs quite a bit in its API from other schedulers (and needs refactoring, to be honest), so this is quite a challenge!
I've left some suggestions to make the full_loop test work; let me know if something's not clear :)
In general, since this scheduler is unique to Stable Diffusion at the moment, feel free to reference the SD pipeline for how it's used in practice.
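For orientation, here is roughly the flow the suggestions below build towards, as a sketch only: the model and the initial sample are simple stand-ins for the dummy_model / dummy_sample_deter helpers from the shared test class, and the scheduler is left at its default config (the real test varies beta_start / beta_end / beta_schedule via check_over_configs).

```python
import torch

from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()

num_inference_steps = 10
scheduler.set_timesteps(num_inference_steps)  # populates scheduler.timesteps and scheduler.sigmas


def model(sample, t):
    # toy noise-prediction model standing in for self.dummy_model()
    return 0.1 * sample


# stand-in for self.dummy_sample_deter, scaled to the max sigma as discussed below
sample = torch.ones((1, 3, 8, 8)) * scheduler.sigmas[0]

for i, t in enumerate(scheduler.timesteps):
    # rescale the model input, matching the K-LMS formulation used by the SD pipeline
    sigma = scheduler.sigmas[i]
    scaled_sample = sample / ((sigma**2 + 1) ** 0.5)

    with torch.no_grad():
        model_output = model(scaled_sample, t)

    # the LMS step takes the step index i rather than the raw timestep t
    sample = scheduler.step(model_output, i, sample).prev_sample
```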
tests/test_scheduler.py (Outdated)

self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

def test_schedules(self):
    for schedule in ["linear"]:
We should check over all schedules that are currently supported / in use.
Suggested change:
- for schedule in ["linear"]:
+ for schedule in ["linear", "scaled_linear"]:
Updated to add scaled_linear (squaredcos_cap_v2 is not implemented for LMSDiscreteScheduler).
tests/test_scheduler.py (Outdated)

num_trained_timesteps = len(scheduler)

model = self.dummy_model()
Suggested change:
- num_trained_timesteps = len(scheduler)
- model = self.dummy_model()
+ model = self.dummy_model()
+ num_inference_steps = 10
+ scheduler.set_timesteps(num_inference_steps)
This scheduler requires a .set_timesteps() call. Reference: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L225
Added
tests/test_scheduler.py (Outdated)

num_trained_timesteps = len(scheduler)

model = self.dummy_model()
sample = self.dummy_sample_deter
Suggested change:
- sample = self.dummy_sample_deter
+ sample = self.dummy_sample_deter * scheduler.sigmas[0]
The initial sample needs to be multiplied by max sigma to make sure that it's within the supported range. Ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L229
Makes sense, updated
tests/test_scheduler.py (Outdated)

model = self.dummy_model()
sample = self.dummy_sample_deter

for t in reversed(range(num_trained_timesteps - 1)):
Suggested change:
- for t in reversed(range(num_trained_timesteps - 1)):
+ for i, t in enumerate(scheduler.timesteps):
Can't get rid of enumeration yet, need both i and t for the loop in the current implementation 😅
tests/test_scheduler.py (Outdated)

# 1. predict noise residual
residual = model(sample, t)
# print("residual: ")
# print(residual)
Suggested change:
- # 1. predict noise residual
- residual = model(sample, t)
- # print("residual: ")
- # print(residual)
+ with torch.no_grad():
+     sigma = scheduler.sigmas[i]
+     sample = sample / ((sigma**2 + 1) ** 0.5)
+     model_output = model(sample, t)
Need to rescale the model input here too, to conform to the ODE equation used in K-LMS. Ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L246
tests/test_scheduler.py (Outdated)

# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample).prev_sample

# if t > 0:
# noise = self.dummy_sample_deter
# variance = scheduler.get_variance(t) ** (0.5) * noise
#
# sample = pred_prev_sample + variance
sample = pred_prev_sample
Suggested change:
- # 2. predict previous mean of sample x_t-1
- pred_prev_sample = scheduler.step(residual, t, sample).prev_sample
- # if t > 0:
- # noise = self.dummy_sample_deter
- # variance = scheduler.get_variance(t) ** (0.5) * noise
- #
- # sample = pred_prev_sample + variance
- sample = pred_prev_sample
+ output = scheduler.step(model_output, i, sample)
+ sample = output.prev_sample
The LMS sampler takes i instead of an explicit timestep.
tests/test_scheduler.py (Outdated)

assert abs(result_sum.item() - 259.0883) < 1e-2
assert abs(result_mean.item() - 0.3374) < 1e-3
Suggested change:
- assert abs(result_sum.item() - 259.0883) < 1e-2
- assert abs(result_mean.item() - 0.3374) < 1e-3
+ assert abs(result_sum.item() - 1006.3885) < 1e-2
+ assert abs(result_mean.item() - 1.3104) < 1e-3
These are the reference values if all of the above suggestions are applied, make sure they match on your hardware! (if not - we'll need to adjust)
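As a side note, the computation of result_sum and result_mean is not shown in the diff above; presumably it follows the other full-loop scheduler tests in tests/test_scheduler.py, roughly:

```python
# assumed to match the other scheduler tests in this file (not shown in the diff above)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
```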
Adjusted values as necessary with the new output
Thanks for the detailed comments @anton-l! I didn't realize we could follow the implementation in the Stable Diffusion pipeline. Let me know if I am missing any test combinations or edge cases.
Thank you @sidthekidder! I think all that's left is to implement test_pytorch_equal_numpy(), could you add that? :)
And feel free to modify the scheduler if it fails this test for some reason!
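A rough sketch of what such a test could look like, assuming the scheduler accepts the same tensor_format config ("np" vs "pt") that the other schedulers in tests/test_scheduler.py used at the time; the config name, step signature, and tolerance here are assumptions and may need adjusting:

```python
import numpy as np
import torch

from diffusers import LMSDiscreteScheduler


def test_pytorch_equal_numpy():
    # assumption: tensor_format switches the scheduler between numpy and torch tensors,
    # as in the other scheduler tests in this file
    scheduler_pt = LMSDiscreteScheduler(tensor_format="pt")
    scheduler_np = LMSDiscreteScheduler(tensor_format="np")

    num_inference_steps = 10
    scheduler_pt.set_timesteps(num_inference_steps)
    scheduler_np.set_timesteps(num_inference_steps)

    sample_pt = torch.ones((1, 3, 8, 8))
    residual_pt = 0.1 * sample_pt
    sample_np = sample_pt.numpy()
    residual_np = 0.1 * sample_np

    # the LMS step takes the step index rather than an explicit timestep
    output_pt = scheduler_pt.step(residual_pt, 0, sample_pt).prev_sample
    output_np = scheduler_np.step(residual_np, 0, sample_np).prev_sample

    # outputs should agree up to numerical precision (tolerance is an assumption)
    assert np.abs(np.asarray(output_np) - output_pt.numpy()).sum() < 1e-4
```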
Commits d340b18 to 629b934
Added
Nice, looks good to me :-)
@anton-l if you could review one more time, this would be great :-)
I'll just remove the unused past residuals (they're needed only in the PNDM tests) and we're good :)
Thank you @sidthekidder!
PR for #338, still requires some fixes