Valid DeviceArray slicing in an xmap analysis triggers obsolete slicing exception #6240

Closed
3 tasks done
sjperkins opened this issue Mar 26, 2021 · 2 comments
Labels
bug Something isn't working


sjperkins commented Mar 26, 2021

I'm pretty excited about the new xmap feature, so I took it for a spin. I know xmap is a bit unstable at the moment (or I may have done something incorrectly!), so I thought I'd raise the following issue.

xmap doesn't seem to like the following (although jit is fine with it), triggering what looks like #620.

    l = lm[:, 0, None, None]  # noqa

producing the following exception

IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 0.

This occurs on jax==0.2.11 and jaxlib==0.1.64.

I don't think it is #620 as I'll explain at the end of the traceback:

See the full example below:

Please:

  • Check for duplicate issues.
  • Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:
import jax
import jax.config
from jax.experimental.maps import xmap, mesh
import jax.numpy as np
from jax import jit
import numpy as onp

minus_two_pi_over_c = -2*np.pi/3e8

@jit
def phase_delay(lm, uvw, frequency):
    out_dtype = np.result_type(lm, uvw, frequency, np.complex64)

    one = lm.dtype.type(1.0)
    neg_two_pi_over_c = lm.dtype.type(minus_two_pi_over_c)
    complex_one = out_dtype.type(1j)

    l = lm[:, 0, None, None]  # noqa
    m = lm[:, 1, None, None]

    u = uvw[None, :, 0, None]
    v = uvw[None, :, 1, None]
    w = uvw[None, :, 2, None]

    n = np.sqrt(one - l**2 - m**2) - one

    real_phase = (neg_two_pi_over_c *
                  (l * u + m * v + n * w) *
                  frequency[None, None, :])

    return np.exp(complex_one*real_phase)

if __name__ == "__main__":
    jax.config.update("jax_enable_x64", True)
    src = 10
    row = 1000
    chan = 64
    key = jax.random.PRNGKey(42)

    lm = jax.random.normal(key, (src, 2)) - 0.5
    uvw = (jax.random.normal(key, (row, 3)) - 0.5)*10000
    freq = np.linspace(.856e9, 2*.856e9, chan)

    with mesh(onp.array(jax.devices()), ("row",)):
        xphase = xmap(phase_delay,
                      in_axes=(["source", "lm"],
                               ["row", "uvw"],
                               ["chan"]),
                      out_axes=["source", "row", "chan"],
                      axis_resources={"row": "row"})

        xphase(lm, uvw, freq)
  • If applicable, include full error messages/tracebacks.
$ XLA_FLAGS="--xla_force_host_platform_device_count=8" python  test_xmap.py 
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/experimental/maps.py:411: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
Traceback (most recent call last):
  File "test_xmap.py", line 95, in <module>
    xphase(lm, uvw, freq)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/experimental/maps.py", line 524, in fun_mapped
    backend=backend)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/experimental/maps.py", line 651, in bind
    return core.call_bind(self, fun, *args, **params)  # type: ignore
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/experimental/maps.py", line 654, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/core.py", line 600, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/experimental/maps.py", line 539, in xmap_impl
    axis_resources, resource_env, backend, *in_avals)(*args)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/experimental/maps.py", line 554, in make_xmap_callable
    jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1228, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1208, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/api.py", line 382, in f_jitted
    return cpp_jitted_f(context, *args, **kwargs)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/api.py", line 351, in cache_miss_wrapper
    def cache_miss_wrapper(_, *args, **kw): return cache_miss(*args, **kw)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/api.py", line 301, in cache_miss
    donated_invars=donated_invars)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/core.py", line 1402, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/core.py", line 1405, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1082, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1208, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "test_xmap.py", line 20, in phase_delay
    l = lm[:, 0, None, None]  # noqa
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/core.py", line 527, in __getitem__
    def __getitem__(self, idx): return self.aval._getitem(self, idx)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/_src/numpy/lax_numpy.py", line 4462, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/_src/numpy/lax_numpy.py", line 4469, in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/_src/numpy/lax_numpy.py", line 4553, in _index_to_gather
    idx = _canonicalize_tuple_index(len(x_shape), idx)
  File "/home/sperkins/venv/jax/lib/python3.6/site-packages/jax/_src/numpy/lax_numpy.py", line 4825, in _canonicalize_tuple_index
    raise IndexError(msg.format(len_without_none, arr_ndim))
IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 0.

What I find interesting is that arr_ndim == 0, so it looks to me like the tracing process is trying to index a 0-dimensional array. I may be using xmap incorrectly, though I have done some checking. Any thoughts?
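
For illustration, the same kind of failure can be provoked without xmap. The sketch below assumes that naming both axes of lm in in_axes is comparable to mapping over both axes with two nested jax.vmap calls; that is an analogy, not a statement about xmap internals. take_l is a hypothetical helper, and the exact wording of the error message may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def take_l(lm):
    # Same index expression as in the report: two regular indices (":" and "0")
    # plus two None entries that only insert new axes.
    return lm[:, 0, None, None]


lm = jnp.zeros((10, 2))

# Unmapped, lm is 2-D inside the function, so the indexing is valid.
print(take_l(lm).shape)  # (10, 1, 1)

# Mapping over both axes leaves a 0-d value inside the function,
# and the same index expression is then rejected.
try:
    jax.vmap(jax.vmap(take_l))(lm)
    error = None
except IndexError as exc:
    error = exc

print(type(error).__name__)
```

The two regular indices are legal for a 2-D array but not for the 0-d value that remains once both axes have been mapped away, which matches arr_ndim == 0 in the traceback.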

@sjperkins sjperkins added the bug Something isn't working label Mar 26, 2021
sjperkins (Author) commented

Ok, I can get this to work and produce my expected result if I do the following:

@jit
def phase_delay(lm, uvw, frequency):
    out_dtype = np.result_type(lm, uvw, frequency, np.complex64)

    one = lm.dtype.type(1.0)
    constant = lm.dtype.type(minus_two_pi_over_c)
    constant = constant * out_dtype.type(1j)

    term = one - lm[0]**2 - lm[1]**2
    term = np.where(term > 0.0, term, 0.0)
    n = np.sqrt(term)

    inner_prod = lm[0]*uvw[0] + lm[1]*uvw[1] + (n - one)*uvw[2]
    return np.exp(constant*inner_prod*frequency)

if __name__ == "__main__":
    jax.config.update("jax_enable_x64", True)
    src = 10
    row = 1000
    chan = 64
    key = jax.random.PRNGKey(42)

    lm = jax.random.normal(key, (src, 2)) - 0.5
    uvw = (jax.random.normal(key, (row, 3)) - 0.5)*10000
    freq = np.linspace(.856e9, 2*.856e9, chan)

    with mesh(onp.array(jax.devices()), ("row",)):
        xphase = xmap(phase_delay,
                      in_axes=(["source", ...],
                               ["row", ...],
                               ["chan"]),
                      out_axes=["source", "row", "chan"],
                      axis_resources={"row": "row"})

        xphase(lm, uvw, freq)

but in the phase_delay function, lm, uvw and frequency are now typed as an array of length 2, array of length 3 and a scalar respectively.

So the primary input axes get removed when the array types are traced in the jitted function. Is this expected?
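
A small sketch of how the shapes seen inside a mapped function can be checked, using nested jax.vmap as a stand-in for xmap (an assumption; xmap itself is not exercised here). record_shapes and seen_shapes are hypothetical names introduced for the example, and the array sizes mirror those in the report.

```python
import jax
import jax.numpy as jnp

seen_shapes = {}


def record_shapes(lm, uvw, frequency):
    # Runs while tracing; the shapes are those of the per-element values.
    seen_shapes["lm"] = lm.shape
    seen_shapes["uvw"] = uvw.shape
    seen_shapes["frequency"] = frequency.shape
    return lm.sum() + uvw.sum() + frequency


src, row, chan = 10, 1000, 64
lm = jnp.zeros((src, 2))
uvw = jnp.zeros((row, 3))
freq = jnp.zeros((chan,))

# One vmap per mapped axis, innermost first: chan, then row, then source.
over_chan = jax.vmap(record_shapes, in_axes=(None, None, 0))
over_row = jax.vmap(over_chan, in_axes=(None, 0, None))
over_source = jax.vmap(over_row, in_axes=(0, None, None))

out = over_source(lm, uvw, freq)
print(seen_shapes)  # {'lm': (2,), 'uvw': (3,), 'frequency': ()}
print(out.shape)    # (10, 1000, 64)
```

The mapped axes are absent inside the function and reappear only in the output, which is consistent with the array of length 2, array of length 3 and scalar described above.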

sjperkins (Author) commented

After playing with vmap a bit more, I understand that the vectorised axis is not present when traced through the vectorised function.

So I understand xmap to be behaving as vmap does.
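
To make the vmap comparison concrete, here is a hedged sketch that checks a per-element formulation under nested jax.vmap against a hand-broadcast formulation. It is not the code from the report: phase_broadcast and phase_element are simplified, hypothetical variants of phase_delay, the inputs are small made-up values chosen so the square root stays real, and xmap is not used.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

MINUS_TWO_PI_OVER_C = -2 * jnp.pi / 3e8


def phase_broadcast(lm, uvw, frequency):
    # Whole-array version: axes are lined up by hand with None indices.
    l = lm[:, 0, None, None]
    m = lm[:, 1, None, None]
    u = uvw[None, :, 0, None]
    v = uvw[None, :, 1, None]
    w = uvw[None, :, 2, None]
    n = jnp.sqrt(1.0 - l**2 - m**2) - 1.0
    real_phase = MINUS_TWO_PI_OVER_C * (l * u + m * v + n * w) * frequency[None, None, :]
    return jnp.exp(1j * real_phase)


def phase_element(lm, uvw, frequency):
    # Per-element version: lm is (2,), uvw is (3,), frequency is a scalar.
    n = jnp.sqrt(1.0 - lm[0]**2 - lm[1]**2) - 1.0
    real_phase = MINUS_TWO_PI_OVER_C * (lm[0] * uvw[0] + lm[1] * uvw[1] + n * uvw[2]) * frequency
    return jnp.exp(1j * real_phase)


src, row, chan = 4, 5, 3
lm = jnp.stack([jnp.linspace(-0.1, 0.1, src), jnp.linspace(0.05, -0.05, src)], axis=1)
uvw = jnp.linspace(-100.0, 100.0, row * 3).reshape(row, 3)
freq = jnp.linspace(0.856e9, 2 * 0.856e9, chan)

# Nest one vmap per output axis: chan innermost, source outermost.
over_chan = jax.vmap(phase_element, in_axes=(None, None, 0))
over_row = jax.vmap(over_chan, in_axes=(None, 0, None))
over_source = jax.vmap(over_row, in_axes=(0, None, None))

mapped = over_source(lm, uvw, freq)
broadcast = phase_broadcast(lm, uvw, freq)

print(mapped.shape)
print(bool(jnp.allclose(mapped, broadcast)))
```

The per-element function never mentions the source, row or chan axes; the nesting order of the vmap calls alone determines the (source, row, chan) layout of the result.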
