JaxStackTraceBeforeTransformation error with hyper-parameter optimization involving complex dtypes #13629

Closed
LouisDesdoigts opened this issue Dec 13, 2022 · 3 comments
Labels
bug Something isn't working

Comments

@LouisDesdoigts

Description

I am trying to optimize hyper-parameters of a model whose evaluation involves complex numbers. Differentiating through the inner optimization loop raises JaxStackTraceBeforeTransformation: TypeError: Cannot interpret 'Zero(ShapedArray(complex64[10]))' as a data type.

Here is a minimal example:

import jax
import jax.numpy as np

def model_fn(x):
    """Create complex array using input"""
    array = np.linspace(-x, x, 10)
    return np.abs(1j*array) # No error thrown if the `1j*` is removed

@jax.value_and_grad
def inner_loss_fn(x, data):
    return np.square(model_fn(x) - data).sum()

@jax.value_and_grad
def loss_fn(lr, x, data):
    for i in range(10):
        inner_loss, grad = inner_loss_fn(x, data)
        x = x - lr * grad
    return inner_loss

data = model_fn(0.)
inner_loss, inner_grad = inner_loss_fn(1., data) # No error thrown
loss, grad = loss_fn(1., 1., data) # Error thrown

Here is the full stack trace:

Traceback (most recent call last):
  File "/Users/louis/PhD/dLux/sandbox/Deconvolution/example.py", line 15, in <module>
    @jax.value_and_grad
  File "/Users/louis/PhD/dLux/sandbox/Deconvolution/example.py", line 17, in loss_fn
    for i in range(10):
  File "/Users/louis/PhD/dLux/sandbox/Deconvolution/example.py", line 13, in inner_loss_fn
    return np.square(model_fn(x) - data).sum()
  File "/Users/louis/PhD/dLux/sandbox/Deconvolution/example.py", line 7, in model_fn
    array = np.linspace(-x, x, 10)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py", line 195, in absolute
    return x if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: Cannot interpret 'Zero(ShapedArray(complex64[10]))' as a data type

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/louis/PhD/dLux/sandbox/Deconvolution/example.py", line 24, in <module>
    loss, grad = loss_fn(1., 1., data)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/api.py", line 1156, in value_and_grad_f
    g = vjp_py(lax_internal._one(ans))
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/tree_util.py", line 292, in __call__
    return self.fun(*args, **kw)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/api.py", line 2523, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/tree_util.py", line 292, in __call__
    return self.fun(*args, **kw)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/interpreters/ad.py", line 140, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/interpreters/ad.py", line 240, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/interpreters/ad.py", line 625, in call_transpose
    out_flat = primitive.bind(fun, *all_args, **params)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/core.py", line 1939, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/core.py", line 1955, in call_bind
    outs = top_trace.process_call(primitive, fun_, tracers, params)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/core.py", line 701, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/dispatch.py", line 234, in _xla_call_impl
    compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/linear_util.py", line 309, in memoized_fun
    ans = call(fun, *args)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/dispatch.py", line 342, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/profiler.py", line 313, in wrapper
    return func(*args, **kwargs)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/dispatch.py", line 428, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/profiler.py", line 313, in wrapper
    return func(*args, **kwargs)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2080, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2030, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/interpreters/ad.py", line 246, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 2017, in _conj_transpose_rule
    return [conj(t)]
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/lax/lax.py", line 407, in conj
    return conj_p.bind(x, input_dtype=_dtype(x))
  File "/Users/louis/mambaforge/envs/dlux/lib/python3.10/site-packages/jax/_src/dtypes.py", line 445, in dtype
    dt = np.result_type(x)
  File "<__array_function__ internals>", line 180, in result_type
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Cannot interpret 'Zero(ShapedArray(complex64[10]))' as a data type

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/louis/PhD/dLux/sandbox/Deconvolution/example.py", line 24, in <module>
    loss, grad = loss_fn(1., 1., data)
  File "<__array_function__ internals>", line 180, in result_type
TypeError: Cannot interpret 'Zero(ShapedArray(complex64[10]))' as a data type

I am at a loss as to how to move forward from here, so any help is greatly appreciated!

What jax/jaxlib version are you using?

0.3.23 / 0.3.22

Which accelerator(s) are you using?

CPU

Additional system info

Mac

NVIDIA GPU info

No response

@LouisDesdoigts added the bug (Something isn't working) label on Dec 13, 2022
@Justin-Tan

Running into a similar issue for model architectures which involve manipulations of complex numbers.

@rajasekharporeddy
Contributor

Hi @LouisDesdoigts

I ran the code above in Google Colab with the latest JAX version (0.4.23), and it executed without any error. Please refer to the gist.

I also ran it on a Mac with the same JAX version (0.4.23), and it executed without error. Please see the screenshot for reference.

Please check with the latest version and confirm if you still have the issue.

Thank you.
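
For anyone confirming this locally, a minimal sketch of a version check is shown below. It uses only the Python standard library; the helper names `installed_version` and `version_tuple` and the constant `KNOWN_GOOD_JAX` are hypothetical and introduced here for illustration. The thread only establishes that JAX 0.4.23 runs the example cleanly, not which release first contained the fix, so the comparison is a rough guide rather than an exact boundary.

```python
from importlib import metadata
from itertools import takewhile

# Assumption: 0.4.23 is the JAX release reported in this thread to run the
# example without error. The release that first contained the fix is unknown.
KNOWN_GOOD_JAX = "0.4.23"


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def version_tuple(version):
    """Parse up to three leading numeric components, e.g. '0.3.23' -> (0, 3, 23)."""
    parts = []
    for piece in version.split(".")[:3]:
        # Keep only the leading digits so suffixes such as 'rc1' are ignored.
        digits = "".join(takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


# Report what is installed in the current environment.
for package in ("jax", "jaxlib"):
    print(package, installed_version(package) or "not installed")

jax_version = installed_version("jax")
if jax_version is not None and version_tuple(jax_version) < version_tuple(KNOWN_GOOD_JAX):
    print(f"jax {jax_version} is older than {KNOWN_GOOD_JAX}; "
          "consider upgrading before re-running the example.")
```

If the reported version is older than 0.4.23, upgrading and re-running the minimal example from the issue description is the quickest way to see whether the error still occurs.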

@LouisDesdoigts
Author

Hey @rajasekharporeddy, you are correct: it looks like this issue was fixed in some earlier release. Thanks!
