As of jax 0.4.27, arguments to `jax.debug.callback` are now `jax.Array` rather than `np.ndarray` (see https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-27-may-7-2024).
This breaks the tqdm updates with the error `TypeError: unsupported type for timedelta seconds component: jaxlib.xla_extension.ArrayImpl`.
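A minimal reproduction of this kind of error outside any callback. It assumes a hypothetical progress-bar helper (`eta`, `TOTAL_STEPS` and `SECONDS_PER_STEP` are illustrative names, not tqdm's internals) that computes a remaining-time estimate from the step count and passes it to `datetime.timedelta`. `timedelta` only accepts `int` and `float` (or their subclasses). Arithmetic on a `jax.Array` stays a `jax.Array`, whereas arithmetic on a 0-d NumPy array yields a NumPy scalar, and `np.float64` subclasses Python `float`:

```python
import datetime

import jax.numpy as jnp
import numpy as np

TOTAL_STEPS = 100        # hypothetical run length
SECONDS_PER_STEP = 0.25  # hypothetical per-step timing estimate


def eta(step):
    # Return the estimated remaining time, or the TypeError message if
    # datetime.timedelta rejects the type produced by the arithmetic.
    remaining = (TOTAL_STEPS - step) * SECONDS_PER_STEP
    try:
        return datetime.timedelta(seconds=remaining)
    except TypeError as err:
        return str(err)


# A 0-d jax.Array, the kind of value a callback now receives.
step = jnp.asarray(40, dtype=jnp.int32)

print(eta(step))              # result stays a jax.Array, so timedelta rejects it
print(eta(np.asarray(step)))  # result is an np.float64 scalar, so this prints 0:00:15
```

This only works because the arithmetic ends up as `np.float64`; a result that stays `np.float32` would still be rejected, since `np.float32` is not a `float` subclass.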
It looks like the fix would be to apply `jax.tree.map(np.asarray, args)` to the callback args.
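A sketch of that conversion applied inside a `jax.debug.callback` under `jit`. The callback, the step function and the constants are hypothetical stand-ins for the progress-bar code, and `eta_log` only records the results so they can be inspected. `jax.tree.map(np.asarray, ...)` turns every leaf of the argument pytree back into an `np.ndarray`, which matches what callbacks received before 0.4.27:

```python
import datetime

import jax
import jax.numpy as jnp
import numpy as np

TOTAL_STEPS = 100        # hypothetical run length
SECONDS_PER_STEP = 0.25  # hypothetical per-step timing estimate

eta_log = []  # collects computed ETAs so the callback's effect is observable


def eta_callback(step):
    # Since jax 0.4.27 `step` arrives as a jax.Array; convert every leaf of
    # the argument pytree to np.ndarray before doing host-side arithmetic.
    (step,) = jax.tree.map(np.asarray, (step,))
    # int32 NumPy value times a Python float gives an np.float64 scalar,
    # which datetime.timedelta accepts.
    remaining = (TOTAL_STEPS - step) * SECONDS_PER_STEP
    eta_log.append(datetime.timedelta(seconds=remaining))


@jax.jit
def train_step(step, x):
    # Hypothetical training step that reports progress from inside jit.
    jax.debug.callback(eta_callback, step)
    return x * 0.5, step + 1


x, step = jnp.ones(3), 0
for _ in range(3):
    x, step = train_step(step, x)
jax.effects_barrier()  # wait until every pending callback has run
print(sorted(eta_log, reverse=True))
```

Converting at the top of the callback keeps the rest of the host-side code unchanged, since it sees NumPy values again regardless of the jax version.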
Closing, as this is now fixed in the latest release.