Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix build when flash attention and memory efficient attention are disabled #18761

Merged
merged 11 commits into from
Dec 26, 2023

Conversation

pengwa
Copy link
Contributor

@pengwa pengwa commented Dec 8, 2023

Fix build when flash attention and memory efficient attention are disabled

On a customer env with lower version of CUDA < 11.6. Both flash attention and memory efficient attention is turned OFF according to

set(onnxruntime_USE_FLASH_ATTENTION OFF)
. So
if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
condition check return false. No cutlass lib is built.

Turn off flash attention since CUDA compiler version < 11.6

While, the kernels in https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/contrib_ops/cuda/moe/ft_moe are depending on cutass for its build, so we get error like this:

[ 77%] Building CUDA object CMakeFiles/onnxruntime_providers_cuda.dir/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu.o
In file included from /tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu:17:
/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h:23:10: fatal error: cutlass/array.h: No such file or directory
   23 | #include "cutlass/array.h"
      |          ^~~~~~~~~~~~~~~~~
compilation terminated.
In file included from /tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu:17:
/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h:23:10: fatal error: cutlass/array.h: No such file or directory
   23 | #include "cutlass/array.h"
      |          ^~~~~~~~~~~~~~~~~
compilation terminated.
In file included from /tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu:17:
/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h:23:10: fatal error: cutlass/array.h: No such file or directory
   23 | #include "cutlass/array.h"
      |          ^~~~~~~~~~~~~~~~~
compilation terminated.
In file included from /tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu:17:
/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h:23:10: fatal error: cutlass/array.h: No such file or directory
   23 | #include "cutlass/array.h"
      |          ^~~~~~~~~~~~~~~~~
compilation terminated.
fatal   : Could not open input file /tmp/tmpxft_00044da3_00000000-11_moe_gemm_kernels_fp16_fp16.compute_60.cpp1.ii
make[2]: *** [CMakeFiles/onnxruntime_providers_cuda.dir/build.make:6290: CMakeFiles/onnxruntime_providers_cuda.dir/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu.o] Error 1
make[2]: *** Waiting for unfinished jobs....
make[1]: *** [CMakeFiles/Makefile2:2210: CMakeFiles/onnxruntime_providers_cuda.dir/all] Error 2
make: *** [Makefile:166: all] Error 2
Traceback (most recent call last):
  File "/tmp/onnxruntime/tools/ci_build/build.py", line 2746, in <module>
    sys.exit(main())
  File "/tmp/onnxruntime/tools/ci_build/build.py", line 2639, in main
    build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, args.target)
  File "/tmp/onnxruntime/tools/ci_build/build.py", line 1527, in build_targets
    run_subprocess(cmd_args, env=env)
  File "/tmp/onnxruntime/tools/ci_build/build.py", line 824, in run_subprocess
    return run(*args, cwd=cwd, capture_stdout=capture_stdout, shell=shell, env=my_env)
  File "/tmp/onnxruntime/tools/python/util/run.py", line 49, in run
    completed_process = subprocess.run(
  File "/opt/conda/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,

Motivation and Context

To summarize, there are two cases we will have build failure for Linux CUDA build:

  1. User use cuda version < 11.6
  2. User disabled Flash attention and memory efficient attention explictly with onnxruntime_USE_FLASH_ATTENTION and onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION

@pengwa pengwa requested review from wangyems and askhade December 8, 2023 13:00
cmake/CMakeLists.txt Outdated Show resolved Hide resolved
@askhade
Copy link
Contributor

askhade commented Dec 11, 2023

LGTM.
Would it be easier to make the change in cmake to remove all the files from moe folder if cutlass is not available instead of adding the #ifdef in every file?

@pengwa
Copy link
Contributor Author

pengwa commented Dec 12, 2023

LGTM. Would it be easier to make the change in cmake to remove all the files from moe folder if cutlass is not available instead of adding the #ifdef in every file?
Thanks for the review.

Oh, I thought adding #ifdef would make the cutlass dependency super clear for the owner or readers. Ideally, if the kernel want, they can use cutlass implementation or other fallback implementation (when without cutlass).

Do the removal trick in the cmake file, might be a bit implicit, but anyway, I can do it if you strongly suggesting doing this way.

@pengwa
Copy link
Contributor Author

pengwa commented Dec 26, 2023

Thank @tianleiwu !

@pengwa pengwa merged commit 37f7436 into main Dec 26, 2023
92 of 100 checks passed
@pengwa pengwa deleted the pengwa/fix_build_007 branch December 26, 2023 00:57
tianleiwu added a commit that referenced this pull request Jan 25, 2024
tianleiwu added a commit that referenced this pull request Jan 26, 2024
### Description
Since Cutlass can be built with CUDA 11.4 (The minimum CUDA version for
onnxruntime CUDA build), there is no need to have a flag to disable
cutlass.

Changes:
(1) Reverted #18761
(2) remove the condition to build cutlass.
(3) Fix a few build errors or warnings during testing CUDA 11.4 build. 

Note that SM 89 and 90 (including fp8) requires CUDA 11.8 or later.
Flash attention and cutlass fused multihead attention will not be built
for CUDA < 11.6. It is recommended to use CUDA 11.8 or above to build if
you want to support latest GPUs.

It is better to include it in 1.17.0 (otherwise, the release branch
might encounter build failure with CUDA 11.4).

Tests:
(1) Build with flash attention and efficient attention off: **passed**
(2) Build with CUDA 11.4: **passed**

Example build command used in Ubuntu 20.04:
```
export CUDA_HOME=/usr/local/cuda-11.4
export CUDNN_HOME=/usr/lib/x86_64-linux-gnu/
export CUDACXX=/usr/local/cuda-11.4/bin/nvcc

sh build.sh --config Release  --build_shared_lib --parallel  --use_cuda --cuda_version 11.4 \
            --cuda_home $CUDA_HOME --cudnn_home $CUDNN_HOME --build_wheel --skip_tests \
            --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \
            --disable_types float8
```

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
YUNQIUGUO pushed a commit that referenced this pull request Jan 26, 2024
### Description
Since Cutlass can be built with CUDA 11.4 (The minimum CUDA version for
onnxruntime CUDA build), there is no need to have a flag to disable
cutlass.

Changes:
(1) Reverted #18761
(2) remove the condition to build cutlass.
(3) Fix a few build errors or warnings during testing CUDA 11.4 build. 

Note that SM 89 and 90 (including fp8) requires CUDA 11.8 or later.
Flash attention and cutlass fused multihead attention will not be built
for CUDA < 11.6. It is recommended to use CUDA 11.8 or above to build if
you want to support latest GPUs.

It is better to include it in 1.17.0 (otherwise, the release branch
might encounter build failure with CUDA 11.4).

Tests:
(1) Build with flash attention and efficient attention off: **passed**
(2) Build with CUDA 11.4: **passed**

Example build command used in Ubuntu 20.04:
```
export CUDA_HOME=/usr/local/cuda-11.4
export CUDNN_HOME=/usr/lib/x86_64-linux-gnu/
export CUDACXX=/usr/local/cuda-11.4/bin/nvcc

sh build.sh --config Release  --build_shared_lib --parallel  --use_cuda --cuda_version 11.4 \
            --cuda_home $CUDA_HOME --cudnn_home $CUDNN_HOME --build_wheel --skip_tests \
            --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \
            --disable_types float8
```

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants