Half precision (float16 or bfloat16) support #539
Comments
Great! This requires explicitly casting all parameters to the target half-precision dtype (e.g. `brainpy.math.float16`).
One more thing that needs to be taken care of: the coefficients of the Runge-Kutta methods should also be cast into the same half-precision dtype.
Could you let me know how to cast the Runge-Kutta coefficients into the half-precision dtype? Is that something that has to be changed inside the BrainPy framework?
Yes, changes should be made in the BrainPy framework. Note that …
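To illustrate why the integrator coefficients matter, here is a minimal JAX sketch (not BrainPy's actual integrator code) showing how a strongly-typed float32 coefficient silently promotes a float16 state back to float32, and how casting the coefficient to the working dtype avoids this:

```python
import numpy as np
import jax.numpy as jnp

state = jnp.ones(4, dtype=jnp.float16)    # half-precision state variable
coeff = np.array(0.5, dtype=np.float32)   # a Runge-Kutta-style coefficient stored in float32

# Mixing a strongly-typed float32 coefficient with a float16 array promotes the result.
print((coeff * state).dtype)                       # float32

# Casting the coefficient to the working dtype keeps the computation in half precision.
print((coeff.astype(state.dtype) * state).dtype)   # float16
```

Plain Python scalars are weakly typed in JAX and would not trigger this promotion; the issue arises for coefficient tables stored as explicitly-typed float32 arrays.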
Update: I think GPU memory consumption is mostly determined by JAX, which preallocates 75% of the total GPU memory by default. This may be why I don't see a reduction in memory consumption after switching to FP16.
The preallocation can be disabled by setting the `XLA_PYTHON_CLIENT_PREALLOCATE=false` environment variable before JAX is imported.
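For reference, a minimal sketch of disabling the preallocation so that observed GPU memory reflects what the model actually allocates (the environment variables must be set before JAX, and therefore BrainPy, is imported):

```python
import os

# Must be set before jax/brainpy are imported; otherwise the backend
# has already reserved ~75% of the GPU memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternatively, cap the preallocated fraction instead of disabling it entirely:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.3"

import brainpy.math as bm  # noqa: E402  (import deliberately placed after the env setup)
```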
Hi, when running …
It is supported, but the …
I guess we should just add one more condition in the …
Yes!
Does BrainPy fully support half-precision floating-point numbers? I have tried to change some of my own BrainPy code from `brainpy.math.float32` to `brainpy.math.float16` or `brainpy.math.bfloat16` (by explicitly setting the dtype of all variables and using a debugger to make sure that they are not promoted to float32), but the GPU memory consumption and running speed are almost the same as with float32.
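As a quick sanity check (a generic JAX sketch, not specific to any particular BrainPy model), one can confirm whether a jitted update actually stays in float16; if the output dtype comes back as float32, some constant or parameter in the update is still strongly typed as float32:

```python
import jax
import jax.numpy as jnp

def step(v, w):
    # Toy update rule standing in for a model's update; Python scalars are
    # weakly typed in JAX and do not promote the float16 inputs.
    return v + 0.1 * (w - v)

v = jnp.zeros(1000, dtype=jnp.float16)
w = jnp.ones(1000, dtype=jnp.float16)

out = jax.jit(step)(v, w)
print(out.dtype)  # expected: float16
```

Even with all dtypes correct, the speed benefit of float16 depends on whether the underlying XLA kernels exploit the reduced precision on the given hardware, so small models may show little difference; the memory comparison is also only meaningful with preallocation disabled, as discussed above.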