GAN example: Only one backward() call? #594
Good catch. We don't actually support this at the moment. A solution here is to allow a dictionary of options for each optimizer, which allows an arbitrary number of calls:

    def configure_optimizers(self, ...):
        opt_G = {'optimizer': Adam(...), 'frequency': 2, 'lr_scheduler': LRScheduler(...)}
        opt_D = {'optimizer': Adam(...), 'frequency': 1, 'lr_scheduler': LRScheduler(...)}
        return opt_G, opt_D

Here G would be called twice back to back, and D once after. But I'm not sure if this is a clean user experience.
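To make the proposed `frequency` key concrete, here is a minimal sketch of one way a training loop could decide which optimizer steps on each batch, assuming optimizers take turns in the order they are returned, each for `frequency` consecutive batches, with the cycle then repeating. The helper names and the extra `name` key are hypothetical, strings stand in for real optimizer objects, and the code is not taken from Lightning.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import Any, Dict, List


def active_optimizer_index(batch_idx: int, frequencies: List[int]) -> int:
    """Return the index of the optimizer that should step on this batch.

    Optimizer 0 runs for frequencies[0] consecutive batches, then optimizer 1
    for frequencies[1] batches, and so on; the whole cycle then repeats.
    """
    if not frequencies or any(f < 1 for f in frequencies):
        raise ValueError("every frequency must be a positive integer")
    # Cumulative bounds, e.g. [2, 1] -> [2, 3]: batches 0-1 -> opt 0, batch 2 -> opt 1.
    bounds = list(accumulate(frequencies))
    position = batch_idx % bounds[-1]
    return bisect_right(bounds, position)


def schedule(configs: List[Dict[str, Any]], num_batches: int) -> List[str]:
    """List which (hypothetically named) optimizer steps on each batch."""
    frequencies = [cfg.get("frequency", 1) for cfg in configs]
    return [configs[active_optimizer_index(i, frequencies)]["name"] for i in range(num_batches)]


# Strings stand in for real Adam objects; 'name' is a hypothetical label.
opt_G = {"name": "G", "optimizer": "Adam(G)", "frequency": 2}
opt_D = {"name": "D", "optimizer": "Adam(D)", "frequency": 1}
print(schedule([opt_G, opt_D], 6))  # ['G', 'G', 'D', 'G', 'G', 'D']
```

With frequencies 2 and 1 the cycle length is 3, so G steps on batches 0 and 1, D on batch 2, and the pattern repeats.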
@williamFalcon, that API would work for us. @alainjungo: some workarounds, none of them ideal:
You can always return a single loss from the training step that captures both losses. Not sure if I'm correct here, but this seems equivalent and matches PTL paradigms.
I have implemented an API that allows returning optimizers, lr_schedulers, and optimizer_frequencies. If there is agreement on this API, I'll proceed to testing, documenting, and submitting a PR. Another option would be to allow returning a tuple of dictionaries, as @williamFalcon suggested; that would be a minor change for me, and I am willing to do that if it is agreed upon.
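Both return styles carry the same per-optimizer information. A minimal sketch, using hypothetical helpers `from_parallel_lists` and `to_parallel_lists` (not Lightning functions) and assuming equal-length sequences, a default frequency of 1, and `None` for a missing scheduler, converts between the three-list form and one dictionary per optimizer with the keys from @williamFalcon's proposal:

```python
from typing import Any, Dict, List, Sequence, Tuple

# Hypothetical helpers (not part of PyTorch Lightning). Strings stand in
# for real optimizer and scheduler objects in the example below.
def from_parallel_lists(
    optimizers: Sequence[Any],
    lr_schedulers: Sequence[Any],
    optimizer_frequencies: Sequence[int],
) -> List[Dict[str, Any]]:
    """Build one config dict per optimizer from three aligned sequences."""
    if not len(optimizers) == len(lr_schedulers) == len(optimizer_frequencies):
        raise ValueError("all three sequences must have the same length")
    return [
        {"optimizer": opt, "lr_scheduler": sched, "frequency": freq}
        for opt, sched, freq in zip(optimizers, lr_schedulers, optimizer_frequencies)
    ]


def to_parallel_lists(
    configs: Sequence[Dict[str, Any]],
) -> Tuple[List[Any], List[Any], List[int]]:
    """Inverse of from_parallel_lists; assumes a missing 'frequency' means 1."""
    optimizers = [cfg["optimizer"] for cfg in configs]
    lr_schedulers = [cfg.get("lr_scheduler") for cfg in configs]
    frequencies = [cfg.get("frequency", 1) for cfg in configs]
    return optimizers, lr_schedulers, frequencies


configs = from_parallel_lists(["adam_G", "adam_D"], ["sched_G", "sched_D"], [2, 1])
print(configs[0])  # {'optimizer': 'adam_G', 'lr_scheduler': 'sched_G', 'frequency': 2}
print(to_parallel_lists(configs))  # (['adam_G', 'adam_D'], ['sched_G', 'sched_D'], [2, 1])
```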
I like @williamFalcon's API, which seems clear to me.
I'll implement @williamFalcon's API and send a detailed PR over the weekend 👍
In the PyTorch GAN tutorial (https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html), there are two backward() calls for the discriminator. How do you ensure this with your structure, where backward() gets called after the training step?
Best,
Alain
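On the two discriminator backward() calls: PyTorch accumulates gradients into each parameter's `.grad` until they are zeroed, and the gradient of a sum is the sum of the gradients, so calling backward() on the real-batch loss and then on the fake-batch loss (`errD_real` and `errD_fake` in the tutorial) before `optimizerD.step()` should give the same gradient, up to floating-point rounding, as one backward() on `errD_real + errD_fake`. Returning that sum as a single loss would then yield the same discriminator update. The pure-Python sketch below checks this for a hypothetical logistic discriminator D(x) = sigmoid(w*x + b) with hand-derived binary cross-entropy gradients standing in for a real network; the parameter values and batches are made up.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bce_loss(w: float, b: float, xs: List[float], label: float) -> float:
    """Mean binary cross-entropy of D(x) = sigmoid(w*x + b) against one label."""
    total = 0.0
    for x in xs:
        p = sigmoid(w * x + b)
        total += -(label * math.log(p) + (1 - label) * math.log(1 - p))
    return total / len(xs)


def bce_grad(w: float, b: float, xs: List[float], label: float) -> Tuple[float, float]:
    """Analytic gradient of bce_loss, using dLoss/dz = sigmoid(z) - label."""
    gw = gb = 0.0
    for x in xs:
        err = sigmoid(w * x + b) - label
        gw += err * x
        gb += err
    return gw / len(xs), gb / len(xs)


w, b = 0.3, -0.1
real_batch = [1.2, 0.8, 1.5]
fake_batch = [-0.4, 0.1, -1.0]

# Tutorial style: two backward() calls accumulate into the same grad buffers.
grad_w = grad_b = 0.0
for xs, label in ((real_batch, 1.0), (fake_batch, 0.0)):
    gw, gb = bce_grad(w, b, xs, label)
    grad_w += gw
    grad_b += gb


# Single-loss style: differentiate errD = errD_real + errD_fake once.
# Central finite differences keep this check independent of bce_grad.
def err_d(w_: float, b_: float) -> float:
    return bce_loss(w_, b_, real_batch, 1.0) + bce_loss(w_, b_, fake_batch, 0.0)


h = 1e-6
num_w = (err_d(w + h, b) - err_d(w - h, b)) / (2 * h)
num_b = (err_d(w, b + h) - err_d(w, b - h)) / (2 * h)

print(abs(grad_w - num_w) < 1e-6, abs(grad_b - num_b) < 1e-6)  # True True
```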