JaxStackTraceBeforeTransformation error with hyper-parameter optimization involving complex dtypes
#13629
Labels
bug
Description
I am trying to optimize hyper-parameters that involve complex numbers within the model evaluation. This raises the following error:
JaxStackTraceBeforeTransformation: TypeError: Cannot interpret 'Zero(ShapedArray(complex64[10]))' as a data type
Here is a minimal example:
Here is the full stack trace:
I am at a loss as to how to move forward from here, so any help is greatly appreciated!
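A hypothetical sketch of the general pattern described above follows. It is not the reporter's minimal example and is not claimed to reproduce or avoid this error. It fits two real hyper-parameters, `omega` and `gamma`, of an invented damped complex exponential (`model`) by plain gradient descent with `jax.grad`. The model, data, step size, and iteration count are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical model: a damped complex exponential. The hyper-parameters
# (omega, gamma) are real, but the intermediate values are complex64.
def model(params, t):
    omega, gamma = params
    return jnp.exp((-gamma + 1j * omega) * t)  # complex64[10] for t of shape (10,)

def loss(params, t, target):
    residual = model(params, t) - target
    # Reduce the complex residual to a real scalar so jax.grad is well defined;
    # squaring real and imaginary parts avoids the kink of abs() at zero.
    return jnp.sum(jnp.real(residual) ** 2 + jnp.imag(residual) ** 2)

t = jnp.linspace(0.0, 1.0, 10)
target = model(jnp.array([2.0, 0.5]), t)  # synthetic "true" hyper-parameters

params = jnp.array([1.5, 0.3])            # initial guess
grad_fn = jax.jit(jax.grad(loss))         # gradient w.r.t. the first argument

initial_loss = float(loss(params, t, target))
for _ in range(200):
    params = params - 0.05 * grad_fn(params, t, target)  # plain gradient descent
final_loss = float(loss(params, t, target))

print(initial_loss, final_loss, params)
```

Keeping the hyper-parameters real and reducing the complex residual to a real scalar loss matters here: `jax.grad` requires a real-valued scalar output unless `holomorphic=True` is passed, in which case the inputs must be complex as well.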
What jax/jaxlib version are you using?
0.3.23 / 0.3.22
Which accelerator(s) are you using?
CPU
Additional system info
Mac
NVIDIA GPU info
No response