Support for torch.cuda.amp in VQ-VAE training #65
Comments
I think it will be safer to use fp32 for the entire quantize operation.
So, wrapping the quantize operations so that they run in fp32?
Yes. It may work.
Okay! I can make a pull request for this if you want? If not, I can just close this.
If it suffices to reproduce the results of fp32 training, it would definitely be nice to have.
For some reason I can't improve forward-pass speed under FP16 (maybe it is bottlenecked by the FP32 quantize operations?). Memory usage is improved, though. I'll play around with this a little more and then maybe make a pull request.
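For concreteness, a minimal sketch of the change discussed above: keep the quantize step in fp32 while the rest of the model runs under autocast. The helper name and the assumption that the quantize module returns (quantized tensor, latent loss, code indices) are illustrative, not the repository's actual code:

```python
import torch
from torch.cuda import amp


def quantize_fp32(quantize_module, z_e):
    """Run a VQ quantize module in full precision inside an autocast region.

    `quantize_module` is assumed to map encoder features to
    (quantized tensor, latent loss, code indices); this helper is hypothetical.
    """
    # Disable autocast so the distance computation, argmin, and EMA updates
    # stay in fp32, and cast the (possibly fp16) activations back up.
    with amp.autocast(enabled=False):
        return quantize_module(z_e.float())
```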
Feature request for AMP support in VQ-VAE training.
So far, I tried naively modifying the `train` function in `train_vqvae.py`, roughly as sketched below. The MSE error appears normal, but the latent error becomes infinite.
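A minimal sketch of what such a naive modification typically looks like follows; the names (`loader`, `model`, `optimizer`, `latent_loss_weight`) and the loss structure are assumptions modelled on a generic VQ-VAE training loop rather than the exact code in `train_vqvae.py`:

```python
import torch
from torch.cuda import amp

scaler = amp.GradScaler()  # scales the loss to avoid fp16 gradient underflow


def train(loader, model, optimizer, device, latent_loss_weight=0.25):
    criterion = torch.nn.MSELoss()

    for img, _ in loader:
        model.zero_grad()
        img = img.to(device)

        # Forward pass and loss computation run in mixed precision.
        with amp.autocast():
            out, latent_loss = model(img)
            recon_loss = criterion(out, img)
            loss = recon_loss + latent_loss_weight * latent_loss.mean()

        # Backward on the scaled loss, then unscale-and-step the optimizer.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
```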
I'm going to try a few ideas when I have the time. I suspect that half precision and/or loss scaling doesn't play well with the EMA updates. One "workaround" is to replace EMA with the second term of the loss function from the original paper, so that all parameters are updated only through gradients, but that is far from ideal.
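For reference, that gradient-based alternative to EMA amounts to adding the codebook (and commitment) terms from the original VQ-VAE paper to the loss; a sketch with hypothetical names and the paper's beta = 0.25:

```python
import torch.nn.functional as F


def vq_loss_terms(z_e, z_q, beta=0.25):
    """Codebook and commitment losses from the original VQ-VAE paper.

    z_e: encoder output; z_q: the corresponding nearest codebook vectors.
    The codebook term updates the embeddings by gradient descent instead
    of EMA; the commitment term keeps the encoder close to its codes.
    """
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    commitment_loss = F.mse_loss(z_e, z_q.detach())
    return codebook_loss + beta * commitment_loss
```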
Thanks!