Crash due to out-of-range index access #18106

Closed
hawkinsp opened this issue Oct 13, 2023 Discussed in #18103 · 1 comment

@hawkinsp
Collaborator

Discussed in #18103

Originally posted by DanPuzzuoli October 13, 2023
I'm trying to run a jit-compiled gradient and I'm getting the following error:

  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 177, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/pjit.py", line 256, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
                                                 ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/pjit.py", line 167, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/core.py", line 2657, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/core.py", line 389, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/core.py", line 869, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/pjit.py", line 1212, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/pjit.py", line 1196, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/pjit.py", line 1132, in _pjit_call_impl_python
    lowering_parameters=mlir.LoweringParameters()).compile()
                                                   ^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2276, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2624, in from_hlo
    xla_executable, compile_options = _cached_compilation(
                                      ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2531, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/compiler.py", line 294, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/devEnv311/lib/python3.11/site-packages/jax/_src/compiler.py", line 256, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: vector

This error does not get raised if I don't try to jit the gradient function, which makes it difficult to track down what's causing the error. I'm still trying to find a minimal example, but wanted to ask here in case anyone has any insight.
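One way to narrow down a compile-time failure like this is to lower the jitted gradient without compiling it, so the IR that would be handed to XLA can be inspected or attached to a report. The sketch below is only an illustration: `loss` is a hypothetical stand-in (it uses `lax.cond` because the backtrace further down passes through `HloEvaluator::HandleConditional`) and is not claimed to reproduce the crash.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss(x):
    # Hypothetical stand-in for the real objective; not the code from the report.
    # Both branches return an array with the same shape and dtype as the input.
    y = lax.cond(x.sum() > 0, lambda v: v ** 2, lambda v: -v, x)
    return y.sum()


x = jnp.arange(1.0, 4.0)  # [1., 2., 3.]

# Un-jitted gradient: primitives are dispatched one at a time rather than
# being compiled together as a single program.
eager_grad = jax.grad(loss)(x)

# Trace and lower the jitted gradient, stopping before .compile() is called.
lowered = jax.jit(jax.grad(loss)).lower(x)
ir_text = lowered.as_text()  # textual IR for the whole jitted function

print(eager_grad)
print(ir_text[:300])
```

If lowering succeeds but `lowered.compile()` fails, the problem lies in the compiler passes rather than in tracing, and `ir_text` gives the maintainers something concrete to look at even before a minimal example exists.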

On my MacBook this produces:

libc++abi: terminating due to uncaught exception of type std::out_of_range: vector
Abort trap: 6

with this lldb backtrace:

(lldb) bt
* thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGABRT
  * frame #0: 0x000000018d178744 libsystem_kernel.dylib`__pthread_kill + 8
    frame #1: 0x000000018d1afc28 libsystem_pthread.dylib`pthread_kill + 288
    frame #2: 0x000000018d0bdae8 libsystem_c.dylib`abort + 180
    frame #3: 0x000000018d168b84 libc++abi.dylib`abort_message + 132
    frame #4: 0x000000018d1583b4 libc++abi.dylib`demangling_terminate_handler() + 320
    frame #5: 0x000000018ce2ee68 libobjc.A.dylib`_objc_terminate() + 160
    frame #6: 0x000000018d167f48 libc++abi.dylib`std::__terminate(void (*)()) + 16
    frame #7: 0x000000018d16ad34 libc++abi.dylib`__cxxabiv1::failed_throw(__cxxabiv1::__cxa_exception*) + 36
    frame #8: 0x000000018d16ace0 libc++abi.dylib`__cxa_throw + 140
    frame #9: 0x000000018d0e371c libc++.1.dylib`std::__1::__throw_out_of_range[abi:v15006](char const*) + 72
    frame #10: 0x000000018d0e7318 libc++.1.dylib`std::__1::__vector_base_common<true>::__throw_out_of_range() const + 24
    frame #11: 0x000000015defc80c xla_extension.so`std::__1::__vector_base<xla::Shape, std::__1::allocator<xla::Shape>>::__throw_out_of_range() const + 12
    frame #12: 0x000000015defb1c8 xla_extension.so`xla::Shape::tuple_shapes(int) const + 72
    frame #13: 0x000000015df03444 xla_extension.so`xla::ShapeUtil::GetSubshape(xla::Shape const&, absl::lts_20230125::Span<long long const>) + 72
    frame #14: 0x000000015de14c80 xla_extension.so`xla::MutableLiteralBase::CopyFrom(xla::LiteralSlice const&, xla::ShapeIndex const&, xla::ShapeIndex const&, bool) + 156
    frame #15: 0x000000015ad22a50 xla_extension.so`xla::HloEvaluator::HandleGetTupleElement(xla::HloInstruction const*) + 744
    frame #16: 0x000000015db94378 xla_extension.so`absl::lts_20230125::Status xla::HloInstruction::Accept<xla::HloInstruction const*>(xla::DfsHloVisitorBase<xla::HloInstruction const*>*, bool, bool, bool) + 1192
    frame #17: 0x0000000159fcc41c xla_extension.so`absl::lts_20230125::Status xla::HloComputation::Accept<xla::HloInstruction const*>(xla::DfsHloVisitorBase<xla::HloInstruction const*>*) const + 388
    frame #18: 0x000000015ad0c2bc xla_extension.so`xla::HloEvaluator::Evaluate(xla::HloComputation const&, absl::lts_20230125::Span<xla::Literal const* const>) + 936
    frame #19: 0x000000015ad25140 xla_extension.so`xla::HloEvaluator::HandleConditional(xla::HloInstruction const*) + 304
    frame #20: 0x000000015ad0d334 xla_extension.so`xla::HloEvaluator::EvaluateInternal(xla::HloInstruction const*, xla::ShapeIndex const&, bool) + 512
    frame #21: 0x000000015ad0cfd0 xla_extension.so`xla::HloEvaluator::Evaluate(xla::HloInstruction const*, bool) + 228
    frame #22: 0x000000015ad0da84 xla_extension.so`xla::HloEvaluator::TryEvaluate(xla::HloInstruction const*, xla::Literal*, bool) + 44
    frame #23: 0x000000015ac67538 xla_extension.so`xla::HloConstantFolding::Run(xla::HloModule*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) + 732
    frame #24: 0x000000015afc07dc xla_extension.so`xla::HloPassPipeline::RunHelper(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) + 52
    frame #25: 0x000000015afc05d0 xla_extension.so`absl::lts_20230125::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)::'lambda'(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)::operator()(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) const + 56
    frame #26: 0x000000015afbe178 xla_extension.so`absl::lts_20230125::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) + 912
    frame #27: 0x000000015afbdcd0 xla_extension.so`xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) + 100
    frame #28: 0x000000015a485df8 xla_extension.so`xla::HloPassFix<xla::HloPassPipeline, 25>::RunOnChangedComputationsOnce(xla::HloModule*, xla::HloPassInterface::RunState*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) + 68
    frame #29: 0x000000015a485aac xla_extension.so`xla::HloPassFix<xla::HloPassPipeline, 25>::RunToFixPoint(xla::HloModule*, xla::HloPassInterface::RunState*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) + 144
    frame #30: 0x000000015a485448 xla_extension.so`xla::HloPassFix<xla::HloPassPipeline, 25>::Run(xla::HloModule*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) + 404
    frame #31: 0x000000015afc07dc xla_extension.so`xla::HloPassPipeline::RunHelper(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) + 52
    frame #32: 0x000000015afc05d0 xla_extension.so`absl::lts_20230125::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)::'lambda'(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&)::operator()(xla::HloPassInterface*, xla::HloModule*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) const + 56
    frame #33: 0x000000015afbe178 xla_extension.so`absl::lts_20230125::StatusOr<bool> xla::HloPassPipeline::RunPassesInternal<xla::HloModule>(xla::HloModule*, xla::DebugOptions const&, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) + 912
    frame #34: 0x000000015afbdcd0 xla_extension.so`xla::HloPassPipeline::Run(xla::HloModule*, absl::lts_20230125::flat_hash_set<std::__1::basic_string_view<char, std::__1::char_traits<char>>, absl::lts_20230125::container_internal::StringHash, absl::lts_20230125::container_internal::StringEq, std::__1::allocator<std::__1::basic_string_view<char, std::__1::char_traits<char>>>> const&) + 100
    frame #35: 0x000000015a46d7c4 xla_extension.so`xla::cpu::CpuCompiler::RunHloPassesThroughLayoutAssn(xla::HloModule*, bool, xla::cpu::LLVMTargetMachineFeatures*, bool) + 5368
    frame #36: 0x000000015a4702b4 xla_extension.so`xla::cpu::CpuCompiler::RunHloPasses(xla::HloModule*, bool, llvm::TargetMachine*, bool) + 100
    frame #37: 0x000000015a470508 xla_extension.so`xla::cpu::CpuCompiler::RunHloPasses(std::__1::unique_ptr<xla::HloModule, std::__1::default_delete<xla::HloModule>>, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&) + 468
    frame #38: 0x000000015a3fcaf8 xla_extension.so`xla::TfrtCpuClient::Compile(xla::XlaComputation const&, xla::CompileOptions) + 1600
copybara-service bot pushed commits to openxla/xla and tensorflow/tensorflow that referenced this issue Oct 15, 2023
The reproduction comes from jax-ml/jax#18103 but I'm having trouble reproducing it in a unit test.

If this code path is triggered, a tuple with the incorrect size is returned.

Fixes jax-ml/jax#18106

PiperOrigin-RevId: 573666720
@hawkinsp
Collaborator Author

openxla/xla#6344 should have fixed this, and the fix should be present in the next jaxlib release.
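
The thread does not say which jaxlib version number carries the fix, so the sketch below only reports what is installed locally; it is a small helper for comparing against the jaxlib release notes, not an official check. The function name `installed_version` is made up for this example.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Print the versions of the two packages involved in this issue.
for name in ("jax", "jaxlib"):
    print(name, installed_version(name))
```

Since the XLA compiler ships inside jaxlib rather than jax, it is the jaxlib version that matters for this crash.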
