[WIP] Make FlaxLMSDiscreteScheduler jittable #3833

Closed

@Dan8991 Dan8991 commented Jun 20, 2023

PR for issue #2180, which aims to make FlaxLMSDiscreteScheduler jittable.

@pcuenca
Member

pcuenca commented Jun 21, 2023

Hi @Dan8991, let us know when you are ready for a review :)

@Dan8991
Author

Dan8991 commented Jun 21, 2023

Yes @pcuenca, thank you very much. Right now I am just doing some testing to better understand how the modules work, so the commits contain useless files. I will squash everything once I have finished making the class jittable, and then I will ask you to review the code :).

@Dan8991 Dan8991 force-pushed the jittable_FlaxLMSDiscreteScheduler branch 5 times, most recently from 82365e4 to da495d2 Compare June 27, 2023 15:09
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@Dan8991
Author

Dan8991 commented Jun 27, 2023

@pcuenca I think I am almost done with the pull request. However, I have a few questions:

  1. In pipeline_flax_stable_diffusion.py, for example, the loop_body function is called through jax.lax.fori_loop, so its input and output must have the same shape. That was not the case in the previous implementation, because the "latents" variable contains an array ("derivatives") whose length grows over time up to "order". The previous code initialized the derivatives array with length zero, which raises an error as soon as its size increases, since the inputs and outputs of the function no longer have matching shapes. At the moment I have hardcoded the shape to the correct value. Would it be possible to add the order argument to the set_timesteps function to avoid this problem?
  2. I believe there were some bugs in the original implementation of FlaxLMSDiscreteSchedulerState, in particular:
  • In the step function, the append and replace functions were called without specifying the axis, which according to the JAX documentation leads to the arrays being flattened (unwanted in this case), so I added the axis values.
  • In the step function, "sigma" and "order" were computed using values from the timestep array. In the non-Flax implementation, however, the two variables depend on the index of the current timestep, so I changed the code to match the non-Flax implementation (which actually made the function work).
    I am not sure whether I should open an issue for these bugs instead of fixing them in this PR. However, fixing them is technically necessary to make FlaxLMSDiscreteScheduler jittable, so it probably makes sense to do it here, no?
  3. I have noticed that whenever the generated image contains NaNs, the script tells the user that the image was obscured because it might contain NSFW content. I find this message uninformative when it is caused by NaNs. Should I open an issue for this?

In any case, I am almost ready for a review. I just need to know your preferred way to fix the non-matching input and output sizes that I mentioned in point 1.
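
To make points 1 and 2 concrete, here is a minimal sketch. It is not the code from this PR: `ORDER`, `LATENT_SHAPE`, `NUM_STEPS`, `sigmas`, `push`, and the body of `loop_body` are all hypothetical stand-ins, and the update is a plain Euler step rather than the LMS formula. It only illustrates (a) that `jnp.append` flattens its inputs when no axis is given, (b) a fixed-size derivatives buffer that keeps the `jax.lax.fori_loop` carry shape constant, and (c) looking up sigma by step index.

```python
import jax
import jax.numpy as jnp

ORDER = 4                    # assumed LMS order, hardcoded as described in point 1
LATENT_SHAPE = (1, 4, 8, 8)  # hypothetical, deliberately small latent shape
NUM_STEPS = 6                # hypothetical number of inference steps

# Without `axis`, jnp.append flattens both arrays before concatenating.
history = jnp.zeros((2,) + LATENT_SHAPE)
newest = jnp.ones((1,) + LATENT_SHAPE)
flattened = jnp.append(history, newest)        # 1-D, shape (768,)
stacked = jnp.append(history, newest, axis=0)  # shape (3, 1, 4, 8, 8)

# Hypothetical noise schedule with one more entry than there are steps.
sigmas = jnp.linspace(1.0, 0.4, NUM_STEPS + 1)


def push(buffer, value):
    """Drop the oldest entry and append `value`; the shape never changes."""
    return jnp.concatenate([buffer[1:], value[None]], axis=0)


def loop_body(i, carry):
    latents, derivatives = carry
    derivative = 0.5 * latents           # placeholder for the model-based derivative
    derivatives = push(derivatives, derivative)
    dt = sigmas[i + 1] - sigmas[i]       # looked up by step index, not by timestep value
    latents = latents + dt * derivative  # plain Euler update, not the LMS formula
    return latents, derivatives


# The derivatives buffer starts at its full size, so the carry passed into
# fori_loop has the same shape and dtype as the carry that comes out.
init = (jnp.ones(LATENT_SHAPE), jnp.zeros((ORDER,) + LATENT_SHAPE))
latents, derivatives = jax.lax.fori_loop(0, NUM_STEPS, loop_body, init)
```

One consequence of preallocating the buffer is that, during the first few steps, some of its entries are still the initial zeros. An actual scheduler would presumably have to track how many entries are valid (for example through the step index) instead of relying on the array length.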

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Jul 22, 2023
@Dan8991 Dan8991 force-pushed the jittable_FlaxLMSDiscreteScheduler branch from da495d2 to 91bd3e0 Compare July 23, 2023 08:41
@Dan8991
Author

Dan8991 commented Jul 23, 2023

@pcuenca I think I am ready for a review. However, please check out my last message before starting and let me know whether my assumptions were correct.

@patrickvonplaten patrickvonplaten requested a review from pcuenca July 24, 2023 18:25

@patrickvonplaten
Contributor

@pcuenca do you think it makes sense to continue this PR?


@github-actions github-actions bot closed this Oct 30, 2023