[Issue]: Test failing with ROCm 6.3.1 on MI250X #120

Open
al-rigazzi opened this issue Jan 29, 2025 · 7 comments

@al-rigazzi

Problem Description

I have built flash-attention in a fresh environment with ROCm 6.3.1, running on MI250X, and I am confused by the test results.

I believe the test file to use is tests/test_flash_attn_ck.py, since a very large portion of the tests in the non-CK file (test_flash_attn.py) fail.

Even so, running pytest tests/test_flash_attn_ck.py produces this failure:

FAILED tests/test_flash_attn_ck.py::test_flash_attn_bwd_overflow[5-16-False-dtype0] - AssertionError: assert 0.0750732421875 <= ((5 * 0.01171875) + 0.001)
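To see how far outside the tolerance this failure is, the assertion can be evaluated directly. The sketch below only re-does the arithmetic from the message; the names `observed`, `baseline` and `bound` are illustrative labels, not variables from the test file. The observed value ends up roughly 26% above the bound.

```python
# Numbers copied verbatim from the pytest assertion message.
observed = 0.0750732421875      # left-hand side of the assertion
baseline = 0.01171875           # value multiplied by 5 on the right-hand side
bound = 5 * baseline + 0.001    # 5x the baseline plus 0.001 of absolute slack

print(f"bound    = {bound:.8f}")
print(f"observed = {observed:.8f}")
print(f"observed / bound = {observed / bound:.2f}")
print("assertion holds" if observed <= bound else "assertion fails")
```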

I have two questions:

  1. Is it normal for this test to fail?
  2. I see that, compared with the standard test_flash_attn.py tests, the tolerance has been raised from a factor of 2 to a factor of 10, with a note that the backward pass (bwd) needs to be fixed. Does this affect the performance of the library when used in production?
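For readers unfamiliar with this style of check: the tolerance "factor" scales the error of a reference implementation rather than being a fixed epsilon. A minimal sketch, assuming the tests compare the kernel's error against a reference with the error of a plain PyTorch implementation against the same reference; the helper `within_tolerance` and the numbers are hypothetical, not taken from the test files.

```python
def within_tolerance(err_impl: float, err_ref: float, factor: float, atol: float = 0.0) -> bool:
    """Hypothetical helper: accept the kernel if its error is at most `factor`
    times the error of a reference implementation, plus an absolute slack `atol`."""
    return err_impl <= factor * err_ref + atol


# Made-up errors, chosen only to show how the factor changes the verdict.
err_kernel = 0.05   # e.g. max |dq - dq_ref| for the kernel under test
err_torch = 0.01    # e.g. max |dq_pt - dq_ref| for a plain PyTorch implementation

print("factor 2: ", within_tolerance(err_kernel, err_torch, factor=2))
print("factor 10:", within_tolerance(err_kernel, err_torch, factor=10))
```

Raising the factor only changes which results the test accepts; the kernel computes the same values either way.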

Operating System

SLES 15-SP5

CPU

AMD EPYC 7A53 64-Core Processor

GPU

AMD Instinct MI250X

ROCm Version

ROCm 6.3.0

ROCm Component

No response

Steps to Reproduce

Torch was installed with

python3 -m pip install --no-cache-dir --pre torch==2.7.0.dev20250128+rocm6.3 --index-url https://download.pytorch.org/whl/nightly/rocm6.3
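Because the outcome in this thread depends on which torch build and Python version are active, it can help to record them alongside test results. A minimal sketch using only standard torch attributes (`torch.__version__`, `torch.version.hip`, `torch.cuda.is_available()`); the function name `env_report` is hypothetical, and on ROCm builds the GPU is reached through the `torch.cuda` API.

```python
import platform


def env_report() -> dict:
    """Collect interpreter and torch build info for a bug report (sketch)."""
    report = {"python": platform.python_version()}
    try:
        import torch
    except ImportError:
        report["torch"] = None
        return report
    report["torch"] = torch.__version__
    # torch.version.hip is a version string on ROCm builds and None otherwise.
    report["hip"] = getattr(torch.version, "hip", None)
    # ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API.
    report["gpu_available"] = torch.cuda.is_available()
    if report["gpu_available"]:
        report["device"] = torch.cuda.get_device_name(0)
    return report


if __name__ == "__main__":
    for key, value in env_report().items():
        print(f"{key}: {value}")
```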

and repo is at

22c0358 (HEAD -> main, tag: v2.7.3-cktile, origin/main, origin/HEAD)

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

@ppanchad-amd

Hi @al-rigazzi. Internal ticket has been created to investigate your issue. Thanks!

@schung-amd

Hi @al-rigazzi, thanks for reporting this!

is it normal for this test to fail?

Looking into this; I wouldn't say it's normal for it to fail (in that we don't intend for it to fail and it's not a known issue), but I don't think we run these tests against the nightly torch builds as part of CI. Is this failing for you with other torch wheels? In particular, we have stable torch wheels in https://repo.radeon.com/rocm/manylinux/ that are more likely to have been tested for this than the wheels on pytorch.org.

I see that, w.r.t. the standard test_flash_attn.py tests, the tolerance has been raised from a factor 2 to a factor 10, mentioning that bwd needs to be fixed. Does this impact the performances of the library, when used in production?

This fix is still pending, not aware of any timeline for it. There is no impact on inference time. In theory there could be some impact on training time (more epochs required), but we haven't heard any reports to this effect thus far.

@al-rigazzi
Author

Thanks, I will try with the wheels you pointed me to (will have to downgrade to PyTorch 2.5) and report back!

@schung-amd

Sorry for the delay, finally had time to try this myself. Reproduced the single test failure with nightly torch 2.7, ROCm 6.3.2, MI210, Ubuntu 24.04, Python 3.12. On the same system, all tests pass with the stable wheels pytorch_triton_rocm-3.0.0+rocm6.3.2.75cc27c26a-cp312-cp312-linux_x86_64.whl and torch-2.4.0+rocm6.3.2-cp312-cp312-linux_x86_64.whl from repo.radeon.com.
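The stable wheels named above carry the `cp312` tag, so pip will install them only on CPython 3.12; a Python 3.11 environment (like the one in the traceback later in this thread) needs differently tagged builds. A simplified sketch of reading those tags from a filename, following the PEP 427 layout `name-version(-build)?-python-abi-platform.whl`; `parse_wheel_filename` and `matches_interpreter` here are hypothetical helpers that ignore compressed tag sets such as `py2.py3` (the `packaging` library's `packaging.utils.parse_wheel_filename` handles the full format).

```python
import sys
from typing import NamedTuple, Optional


class WheelTags(NamedTuple):
    name: str
    version: str
    build: Optional[str]
    python: str
    abi: str
    platform: str


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split a wheel filename into its components (simplified PEP 427 sketch)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        name, version, python, abi, plat = parts
        build = None
    elif len(parts) == 6:
        name, version, build, python, abi, plat = parts
    else:
        raise ValueError(f"unexpected wheel filename: {filename}")
    return WheelTags(name, version, build, python, abi, plat)


def matches_interpreter(tags: WheelTags, major: int, minor: int) -> bool:
    # CPython-specific wheels use tags like "cp312" for Python 3.12.
    return tags.python == f"cp{major}{minor}"


if __name__ == "__main__":
    for wheel in (
        "pytorch_triton_rocm-3.0.0+rocm6.3.2.75cc27c26a-cp312-cp312-linux_x86_64.whl",
        "torch-2.4.0+rocm6.3.2-cp312-cp312-linux_x86_64.whl",
    ):
        tags = parse_wheel_filename(wheel)
        ok = matches_interpreter(tags, sys.version_info.major, sys.version_info.minor)
        print(f"{tags.name} {tags.version} ({tags.python}): installable here = {ok}")
```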

As this test failure appears to be related to the CK precision issue you've noted, I suspect using the nightly torch wheel is fine for the purposes of flash attention, but you can also fall back to the stable wheels where the test passes if you wish.

I'll check to see if this failure is already known internally, but if this is caused by the CK precision issue then there isn't much to do until that is addressed.

@al-rigazzi
Author

@schung-amd thanks for the help. Unfortunately, I tried using the wheels you suggested and this was the result (using the current version of the repo):

Fatal Python error: Segmentation fault

Current thread 0x0000149a8d9ff700 (most recent call first):
  File "/scratch/flash_attention/flash-attention/flash_attn/flash_attn_interface.py", line 263 in _flash_attn_backward
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/_library/custom_ops.py", line 236 in backend_impl
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/_ops.py", line 672 in redispatch
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/_library/custom_ops.py", line 494 in adinplaceorview_impl
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/_ops.py", line 672 in redispatch
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/_library/autograd.py", line 40 in forward
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/autograd/function.py", line 574 in apply
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/_library/autograd.py", line 98 in autograd_impl
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/_ops.py", line 1061 in __call__
  File "/scratch/flash_attention/flash-attention/flash_attn/flash_attn_interface.py", line 842 in backward
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/autograd/function.py", line 306 in apply

Thread 0x0000149e77462740 (most recent call first):
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 768 in _engine_run_backward
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 436 in grad
  File "/scratch/flash_attention/flash-attention/tests/test_flash_attn_ck.py", line 467 in test_flash_attn_output
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/python.py", line 159 in pytest_pyfunc_call
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/python.py", line 1627 in runtest
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/runner.py", line 174 in pytest_runtest_call
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/runner.py", line 242 in <lambda>
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/runner.py", line 341 in from_call
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/runner.py", line 241 in call_and_report
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/runner.py", line 132 in runtestprotocol
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/runner.py", line 113 in pytest_runtest_protocol
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/main.py", line 362 in pytest_runtestloop
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/main.py", line 337 in _main
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/main.py", line 283 in wrap_session
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/main.py", line 330 in pytest_cmdline_main
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/config/__init__.py", line 175 in main
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/_pytest/config/__init__.py", line 201 in console_main
  File "/scratch/flash_attention/flashattention_venv/lib/python3.11/site-packages/pytest/__main__.py", line 9 in <module>
  File "<frozen runpy>", line 88 in _run_code
  File "<frozen runpy>", line 198 in _run_module_as_main

Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator (total: 18)

Is this something you've already come across?

Thanks!
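The "Fatal Python error: Segmentation fault" dump above is the format printed by the standard-library `faulthandler` module, which pytest enables by default. It shows where Python was when the native crash happened (inside `_flash_attn_backward`), but not which native frame faulted. A minimal sketch for getting the same dump outside pytest; `enable_crash_dumps` is a hypothetical helper name, and the commented-out `run_backward()` is a placeholder for whatever call crashes.

```python
import faulthandler
import sys


def enable_crash_dumps() -> None:
    # On SIGSEGV, SIGFPE, SIGABRT, SIGBUS or SIGILL, print the Python traceback
    # of every thread. sys.__stderr__ is the process's original stderr, which
    # still has a real file descriptor even if sys.stderr has been redirected.
    faulthandler.enable(file=sys.__stderr__, all_threads=True)


if __name__ == "__main__":
    enable_crash_dumps()
    print("faulthandler enabled:", faulthandler.is_enabled())
    # run_backward()  # hypothetical placeholder for the crashing call
```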

@schung-amd

Interesting, I haven't seen that before personally; I'll see if I can find some info on those errors. I used the Python 3.12 wheels, though I'm not sure why that would make a difference. I was also on ROCm 6.3.2, so perhaps this was fixed between versions.

@al-rigazzi
Author

Thanks, I tried switching to ROCm 6.3.2 but I still see the same segfault. The first failing test is: tests/test_flash_attn_ck.py::test_flash_attn_output[0.0-1024-1024-59-False-False-False-False-mha-dtype0-False] and the output is

Thread 5 (Thread 0x1551941ff700 (LWP 903784) "pt_autograd_0"):
#0  0x000015549ecdeef6 in ?? () from /software/rocm/6.3.2/lib/libamdhip64.so
#1  0x0000155554c24bd9 in __run_exit_handlers () from /lib64/libc.so.6
#2  0x0000155554c24d6a in exit () from /lib64/libc.so.6
#3  0x0000155341cf3dd2 in fmha_bwd_v3_kernel::fmha_bwd_v3_kernel(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, unsigned char*) () from /scratch/flash_attention/flash-attention/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so
#4  0x0000155341bf16bf in float fmha_bwd_v3_<fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>, fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, false>, fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false> >(ck_tile::stream_config const&, fmha_bwd_args) () from /scratch/flash_attention/flash-attention/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so
#5  0x0000155345a764e0 in mha_bwd(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&) () from /scratch/flash_attention/flash-attention/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so
#6  0x0000155345a6fe39 in std::vector<at::Tensor, std::allocator<at::Tensor> > pybind11::detail::argument_loader<at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&>::call_impl<std::vector<at::Tensor, std::allocator<at::Tensor> >, std::vector<at::Tensor, std::allocator<at::Tensor> > (*&)(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&), 0ul, 1ul, 2ul, 3ul, 4ul, 5ul, 6ul, 7ul, 8ul, 9ul, 10ul, 11ul, 12ul, 13ul, 14ul, 15ul, 16ul, 17ul, 18ul, pybind11::detail::void_type>(std::vector<at::Tensor, std::allocator<at::Tensor> > (*&)(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&), std::integer_sequence<unsigned long, 0ul, 1ul, 2ul, 3ul, 4ul, 5ul, 6ul, 7ul, 8ul, 9ul, 10ul, 11ul, 12ul, 13ul, 14ul, 15ul, 16ul, 17ul, 18ul>, pybind11::detail::void_type&&) && () from /scratch/flash_attention/flash-attention/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so
#7  0x0000155345a6f575 in pybind11::cpp_function::initialize<std::vector<at::Tensor, std::allocator<at::Tensor> > (*&)(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&), std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&, pybind11::name, pybind11::scope, pybind11::sibling, char [14]>(std::vector<at::Tensor, std::allocator<at::Tensor> > (*&)(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&), std::vector<at::Tensor, std::allocator<at::Tensor> > (*)(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, char const (&) [14])::{lambda(pybind11::detail::function_call&)#1}::operator()(pybind11::detail::function_call&) const () from /scratch/flash_attention/flash-attention/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so
#8  0x0000155345a6f49e in pybind11::cpp_function::initialize<std::vector<at::Tensor, std::allocator<at::Tensor> > (*&)(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&), std::vector<at::Tensor, std::allocator<at::Tensor> >, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&, pybind11::name, pybind11::scope, pybind11::sibling, char [14]>(std::vector<at::Tensor, std::allocator<at::Tensor> > (*&)(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&), std::vector<at::Tensor, std::allocator<at::Tensor> > (*)(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, std::optional<at::Tensor>&, float, float, bool, int, int, float, bool, std::optional<at::Generator>, std::optional<at::Tensor>&), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, char const (&) [14])::{lambda(pybind11::detail::function_call&)#1}::__invoke(pybind11::detail::function_call&) () from /scratch/flash_attention/flash-attention/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so
#9  0x0000155345a661ba in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /scratch/flash_attention/flash-attention/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so
#10 0x000015555515dfcd in cfunction_call (func=<built-in method bwd of PyCapsule object at remote 0x15534b2b9da0>, args=<optimized out>, kwargs=<optimized out>) at Objects/methodobject.c:542
[...]

Unfortunately Python 3.12 is not an option for me, but thanks for double-checking!
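The native backtrace in the previous comment comes from a debugger. A minimal sketch for repeating that on a single parametrized test, assuming gdb and pytest are installed; `gdb_pytest_command` is a hypothetical helper that only builds the command line, while `-batch`, `-ex` and `--args` are standard gdb options and `thread apply all bt` is the gdb command that prints every thread's stack. The node ID contains brackets, so it must be quoted when pasted into a shell, which `shlex.join` takes care of.

```python
import shlex
import sys
from typing import List


def gdb_pytest_command(node_id: str, python: str = sys.executable) -> List[str]:
    """Build a command that runs one pytest node under gdb and prints native
    backtraces of all threads when the process stops (sketch)."""
    return [
        "gdb", "-batch",
        "-ex", "run",                  # start the program
        "-ex", "thread apply all bt",  # after the crash, dump every thread's stack
        "--args", python, "-m", "pytest", "-x", node_id,
    ]


if __name__ == "__main__":
    node = ("tests/test_flash_attn_ck.py::test_flash_attn_output"
            "[0.0-1024-1024-59-False-False-False-False-mha-dtype0-False]")
    print(shlex.join(gdb_pytest_command(node)))
```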
