Error when running blackjax in parallel: RWState(position='ShapedArray(float32[1])', logdensity='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])'). #695
-
Hi blackjax team, I ran into some float64/float32 issues when running blackjax in parallel. I know this is not a blackjax bug as such, but I am interested in how to modify the blackjax functions to work around it...
If I run it without parallelism, everything works fine...
However, when I use the joblib package, it reports the error. With jax.config.update("jax_enable_x64", False), the parallel version works well. With jax.config.update("jax_enable_x64", True), it raises this error. I am not sure whether this can be easily addressed in the blackjax functions above. For example, can we manually set RWState to float64, or something similar? I am looking forward to your comments!
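The message looks like a loop-carry type check: JAX loops such as jax.lax.scan require the state returned by each step to have exactly the same dtypes as the state passed in. The toy reproduction below is a sketch under stated assumptions, not blackjax's code: ToyRWState is a hypothetical stand-in for RWState, the Gaussian model and the xobs values are made up, and the float64 is introduced by jax.random.normal, which returns float64 by default once x64 is enabled. Assuming RWState is a NamedTuple (and therefore a pytree), one way to "set RWState as float64" is to cast every leaf of the initial state with jax.tree_util.tree_map.

```python
import jax
jax.config.update("jax_enable_x64", True)  # set before any arrays are created
import jax.numpy as jnp
import numpy as np
from typing import NamedTuple


class ToyRWState(NamedTuple):
    """Hypothetical stand-in for a random-walk state: position and log-density."""
    position: jax.Array
    logdensity: jax.Array


# Hypothetical observed data, stored as float32.
xobs = np.array([0.3, -1.2, 0.8], dtype=np.float32)


def logdensity_fn(theta):
    # Toy unit-variance Gaussian log-likelihood.
    return -0.5 * jnp.sum((theta - xobs) ** 2)


def rw_step(state, key):
    key_prop, key_acc = jax.random.split(key)
    # With x64 enabled, jax.random.normal returns float64 by default, so a
    # float32 position is promoted to float64 here.
    proposal = state.position + 0.1 * jax.random.normal(key_prop, state.position.shape)
    new_logp = logdensity_fn(proposal)
    accept = jnp.log(jax.random.uniform(key_acc)) < new_logp - state.logdensity
    new_state = ToyRWState(
        position=jnp.where(accept, proposal, state.position),
        logdensity=jnp.where(accept, new_logp, state.logdensity),
    )
    return new_state, new_state.logdensity


def sample(init_state, n_steps=100):
    keys = jax.random.split(jax.random.PRNGKey(0), n_steps)
    # lax.scan requires the carry (the state) to keep identical shapes and dtypes.
    return jax.lax.scan(rw_step, init_state, keys)


theta0 = jnp.zeros(1, dtype=jnp.float32)
state32 = ToyRWState(theta0, logdensity_fn(theta0))  # both leaves are float32

try:
    sample(state32)
except TypeError as err:
    print("carry mismatch:", type(err).__name__)

# Cast every leaf of the initial state to float64 so each step preserves the carry type.
state64 = jax.tree_util.tree_map(lambda x: x.astype(jnp.float64), state32)
final_state, logps = sample(state64)
print(final_state.position.dtype, logps.dtype)  # float64 float64
```

Casting the initial state works in this toy because every later step already computes in float64; if a step could instead drop back to float32, the cast would have to happen inside the step as well.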
-
Hi, for parallelization you should use JAX's native functionality, like … However, the error looks like it simply comes from float64 not being set up correctly. You need to put …
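For running several chains at once without joblib, one native JAX option, assumed here, is jax.vmap over a per-chain function. The sketch uses a toy random-walk Metropolis loop in plain JAX rather than blackjax's kernel; run_chain, the model and the data are hypothetical. All inputs are float64, so the loop carry keeps a single dtype under x64.

```python
import jax
jax.config.update("jax_enable_x64", True)  # set before any arrays are created
import jax.numpy as jnp
import numpy as np

# Hypothetical observations, explicitly float64.
xobs = jnp.asarray(np.array([0.3, -1.2, 0.8]), dtype=jnp.float64)


def logdensity_fn(theta):
    # Toy unit-variance Gaussian log-likelihood.
    return -0.5 * jnp.sum((theta - xobs) ** 2)


def run_chain(key, theta0, n_steps=200):
    """One random-walk Metropolis chain (toy stand-in for a blackjax kernel loop)."""
    def step(carry, key):
        theta, logp = carry
        key_prop, key_acc = jax.random.split(key)
        proposal = theta + 0.1 * jax.random.normal(key_prop, theta.shape)
        new_logp = logdensity_fn(proposal)
        accept = jnp.log(jax.random.uniform(key_acc)) < new_logp - logp
        theta = jnp.where(accept, proposal, theta)
        logp = jnp.where(accept, new_logp, logp)
        return (theta, logp), theta

    keys = jax.random.split(key, n_steps)
    _, positions = jax.lax.scan(step, (theta0, logdensity_fn(theta0)), keys)
    return positions


n_chains = 4
chain_keys = jax.random.split(jax.random.PRNGKey(42), n_chains)
theta0s = jnp.zeros((n_chains, 1))  # float64 by default under x64

# vmap maps run_chain over the leading (chain) axis of the keys and initial positions.
positions = jax.jit(jax.vmap(run_chain))(chain_keys, theta0s)
print(positions.shape)  # (4, 200, 1)
```

jax.pmap follows the same calling pattern when the chains should be spread over multiple devices, with the chain axis limited to the number of available devices.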
It meant the output of the logdensity function.
It returns float32 initially, and after one step it returns float64.
Did you also try casting xobs to np.float64?
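The symptom described above (float32 at the start, float64 after one step) can be checked by calling the log-density on a float32 and a float64 position. In this sketch the model and the xobs values are hypothetical, and theta_next merely stands in for whatever makes the position or intermediate values float64 after the first step. With xobs cast to np.float64, the output is float64 from the very first call, so the loop carry no longer changes dtype.

```python
import jax
jax.config.update("jax_enable_x64", True)  # set before any arrays are created
import jax.numpy as jnp
import numpy as np

# Hypothetical float32 observations and their float64 cast.
xobs = np.array([0.3, -1.2, 0.8], dtype=np.float32)
xobs64 = xobs.astype(np.float64)


def make_logdensity(data):
    # Toy unit-variance Gaussian log-likelihood closed over `data`.
    def logdensity_fn(theta):
        return -0.5 * jnp.sum((theta - data) ** 2)
    return logdensity_fn


theta_init = jnp.zeros(1, dtype=jnp.float32)  # float32 starting position
theta_next = jnp.zeros(1, dtype=jnp.float64)  # stands in for a float64 value after one step

results = {}
for name, data in (("float32 xobs", xobs), ("float64 xobs", xobs64)):
    f = make_logdensity(data)
    results[name] = (f(theta_init).dtype, f(theta_next).dtype)
    print(name, *results[name])
# float32 xobs float32 float64
# float64 xobs float64 float64
```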