
Is there any way to compile the codes with nvcc debug flag(-G)? #1364

Open
Dev-Jahn opened this issue Dec 2, 2024 · 6 comments

Dev-Jahn commented Dec 2, 2024

I'm trying to implement custom behaviors on top of flash-attn 3 (Hopper).
Building the library is fine in general, but compilation takes far too long once I add the nvcc -G flag (or --ptxas-options=-g) to debug the mainloop and tile schedulers.

nvcc_flags = [
    "-std=c++17",
    # "-U__CUDA_NO_HALF_OPERATORS__",
    # "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
    "--use_fast_math",
    # "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",  # printing out number of registers
    # "-lineinfo",
    "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
    "-DNDEBUG",
    "-gencode",
    "arch=compute_90a,code=sm_90a",
    "--threads",
    "16",
    "-g",  # Host code debug
    "-O0",  # Host code opt level
    "-G",  # Device code debug
    "-DFLASH_DEBUG",
]

Above is the list of all nvcc flags I'm using.

  • The problem:
  1. Adding -g -O0 to cxx_flags compiles in reasonable time
  2. Adding -g -O0 to nvcc_flags compiles in reasonable time
  3. Adding -G (device code debug) to nvcc_flags never finishes
  • What I've tried:
  1. Making /tmp a ramdisk
  2. Compiling only the minimum .cu files to reduce instantiation:
if debug_mode:
    sources = [
        "flash_api.cpp",
        "flash_fwd_hdim64_bf16_gqa4_sm90.cu",
        "flash_bwd_hdim64_bf16_sm90.cu",
    ]
  3. Narrowing down some switches by explicit instantiation:
// flash_api.cpp
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream,
                 bool force_split_kernel = false) {

    int dtype = 1;
    if (params.is_bf16) {
        dtype = 2;
    } else if (params.is_e4m3) {
        dtype = 3;
    }
#ifdef FLASH_DEBUG
    run_mha_fwd_gqa_<cutlass::bfloat16_t, 64, 4>(params, stream);
#else
    PREC_SWITCH(dtype, Element, [&] {
        HEADDIM_SWITCH(params.d, kHeadSize, [&] {
            if (!params.use_gqa_packing) {
                run_mha_fwd_<Element, kHeadSize>(params, stream);
            } else {
                QUERYHEAD_SWITCH(params.h_h_k_ratio, kBlockH, [&] {
                    run_mha_fwd_gqa_<Element, kHeadSize, kBlockH>(params,
                                                                  stream);
                });
            }
        });
    });
#endif
}

// flash_fwd_launch_template.h
template<typename T, int kBlockH>
void run_mha_fwd_hdim64_gqa(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    constexpr static bool UseCluster = false;
    using Seqlen_traits = flash::FixedSeqLenTraits;
    using Seqlen_traits_Q = flash::FixedGQASeqLenTraits;
#ifdef FLASH_DEBUG
    constexpr static int kNumMmaWGs = 2;
    run_flash_fwd<
              Flash_fwd_kernel_traits<Headdim,
              /*kBlockM_=*/kNumMmaWGs * 64,
              /*kBlockN_=*/128,
              /*kNWarps_=*/4 + kNumMmaWGs * 4,
              /*kStages_=*/2,
              /*Is_Q_in_regs_=*/false,
              /*kClusterM_=*/UseCluster ? 2 : 1,
              /*elem_type=*/T,
              /*Is_split_=*/false,
              /*kBlockH_=*/kBlockH>,
              /*Is_causal=*/false,
              /*Is_local=*/true,
              Seqlen_traits,
              Seqlen_traits_Q
            >(params, stream);
#else
    MMA_3WG_SWITCH(kBlockH * params.seqlen_q, kNumMmaWGs, [&] {
        // ... (rest of the original dispatch, omitted here)
My dev machine has 224 CPU cores, but increasing ninja or nvcc threads is pointless because cicc and ptxas are not parallelizable within a single translation unit.

The ptxas process for a single ptx file takes almost forever (more than 2 hours).

I'm fairly new to CUDA, so I might be missing some important options.
Is there any other way to reduce the compilation time when adding the device debug flag?

(edit)

      [1/3] c++ -MMD -MF /tmp/tmpo2nq5vjf.build-temp/src/flash_api.o.d -pthread -B /home/ubuntu/anaconda3/envs/flash-dev/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /home/ubuntu/anaconda3/envs/flash-dev/include -fPIC -O2 -isystem /home/ubuntu/anaconda3/envs/flash-dev/include -fPIC -I/home/ubuntu/workspace/flash-attention/csrc/cutlass/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/TH -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/ubuntu/anaconda3/envs/flash-dev/include/python3.11 -c -c /home/ubuntu/workspace/flash-attention/2d/src/flash_api.cpp -o /tmp/tmpo2nq5vjf.build-temp/src/flash_api.o -std=c++17 -DFLASHATTENTION_ENABLE_2D -g -O0 -DFLASH_DEBUG -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flashattn_2d_hopper_cuda -D_GLIBCXX_USE_CXX11_ABI=0
      /home/ubuntu/workspace/flash-attention/2d/src/flash_api.cpp: In function ‘void run_mha_fwd(Flash_fwd_params&, cudaStream_t, bool)’:
      /home/ubuntu/workspace/flash-attention/2d/src/flash_api.cpp:457:9: warning: variable ‘dtype’ set but not used [-Wunused-but-set-variable]
        457 |     int dtype = 1;
            |         ^~~~~
      [2/3] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /tmp/tmpo2nq5vjf.build-temp/src/flash_fwd_hdim64_bf16_gqa4_sm90.o.d -I/home/ubuntu/workspace/flash-attention/csrc/cutlass/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/TH -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/ubuntu/anaconda3/envs/flash-dev/include/python3.11 -c -c /home/ubuntu/workspace/flash-attention/2d/src/flash_fwd_hdim64_bf16_gqa4_sm90.cu -o /tmp/tmpo2nq5vjf.build-temp/src/flash_fwd_hdim64_bf16_gqa4_sm90.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -std=c++17 -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT162_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math -DCUTLASS_DEBUG_TRACE_LEVEL=0 -DNDEBUG -DFLASHATTENTION_ENABLE_2D -gencode arch=compute_90a,code=sm_90a --threads 16 -g -O0 -G -DFLASH_DEBUG -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flashattn_2d_hopper_cuda -D_GLIBCXX_USE_CXX11_ABI=0
      Warning: Function too large, generated debug information may not be accurate.

      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi128ELi128ELi12ELi2ELb0ELi1EN7cutlass10bfloat16_tELb0ELi4EELb0ELb1ENS_30DynamicPersistentTileSchedulerILi256ELi32ELb0EEENS_12SeqLenTraitsILb0ELb0ELb0EEENS7_ILb0ELb0ELb1EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ELb1ET3_T4_E6ParamsENS_21CollectiveEpilogueFwdISB_SD_E6ParamsENT2_6ParamsESD_SC_'
      ptxas info    : (C7511) Potential Performance Loss: wgmma.mma_async instructions are serialized due to insufficient register resources for the wgmma pipeline in the function '_ZN4cute28SM90_64x64x16_F32BF16BF16_RSILNS_4GMMA5MajorE0ELS2_1ELNS1_7ScaleInE1ELS3_1EE3fmaERKjS6_S6_S6_RKmRfS9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_NS1_8ScaleOutE'
      ptxas info    : (C7511) Potential Performance Loss: wgmma.mma_async instructions are serialized due to insufficient register resources for the wgmma pipeline in the function '_ZN4cute29SM90_64x128x16_F32BF16BF16_SSILNS_4GMMA5MajorE0ELS2_0ELNS1_7ScaleInE1ELS3_1EE3fmaERKmS6_RfS7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_NS1_8ScaleOutE'
      [3/3] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /tmp/tmpo2nq5vjf.build-temp/src/flash_bwd_hdim64_bf16_sm90.o.d -I/home/ubuntu/workspace/flash-attention/csrc/cutlass/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/TH -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/ubuntu/anaconda3/envs/flash-dev/include/python3.11 -c -c /home/ubuntu/workspace/flash-attention/2d/src/flash_bwd_hdim64_bf16_sm90.cu -o /tmp/tmpo2nq5vjf.build-temp/src/flash_bwd_hdim64_bf16_sm90.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -std=c++17 -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT162_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math -DCUTLASS_DEBUG_TRACE_LEVEL=0 -DNDEBUG -DFLASHATTENTION_ENABLE_2D -gencode arch=compute_90a,code=sm_90a --threads 16 -g -O0 -G -DFLASH_DEBUG -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flashattn_2d_hopper_cuda -D_GLIBCXX_USE_CXX11_ABI=0
      FAILED: /tmp/tmpo2nq5vjf.build-temp/src/flash_bwd_hdim64_bf16_sm90.o
      /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /tmp/tmpo2nq5vjf.build-temp/src/flash_bwd_hdim64_bf16_sm90.o.d -I/home/ubuntu/workspace/flash-attention/csrc/cutlass/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/TH -I/home/ubuntu/anaconda3/envs/flash-dev/lib/python3.11/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/ubuntu/anaconda3/envs/flash-dev/include/python3.11 -c -c /home/ubuntu/workspace/flash-attention/2d/src/flash_bwd_hdim64_bf16_sm90.cu -o /tmp/tmpo2nq5vjf.build-temp/src/flash_bwd_hdim64_bf16_sm90.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -std=c++17 -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT162_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math -DCUTLASS_DEBUG_TRACE_LEVEL=0 -DNDEBUG -DFLASHATTENTION_ENABLE_2D -gencode arch=compute_90a,code=sm_90a --threads 16 -g -O0 -G -DFLASH_DEBUG -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flashattn_2d_hopper_cuda -D_GLIBCXX_USE_CXX11_ABI=0
      Warning: Function too large, generated debug information may not be accurate.

      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb0ELb1ELb1ELb1ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb1EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb1ELb0ELb1ELb0ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb1EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb0ELb0ELb1ELb1ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb1EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb1ELb0ELb1ELb1ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb1EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb0ELb1ELb0ELb1ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb0EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb0ELb0ELb0ELb1ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb0EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb1ELb0ELb0ELb1ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb0EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb0ELb1ELb0ELb0ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb0EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb0ELb0ELb0ELb0ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb0EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb1ELb0ELb0ELb0ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb0EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb0ELb1ELb1ELb0ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb1EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb0ELb0ELb1ELb0ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb1EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb1ELb0ELb1ELb0ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb1EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb0ELb1ELb1ELb0ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb1EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb0ELb0ELb1ELb0ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb1EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb0ELb0ELb1ELb1ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb1EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb1ELb0ELb1ELb1ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb1EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb0ELb1ELb1ELb1ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb1EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb0ELb1ELb0ELb0ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb0EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb0ELb1ELb0ELb1ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb0EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb0ELb0ELb0ELb0ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb0EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb0ELb0ELb0ELb1ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb0EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZZN5flash21CollectiveMainloopBwdILi2EN4cute5tupleIJNS1_1CILi1EEES4_S4_EEENS2_IJNS3_ILi128EEES6_NS3_ILi64EEEEEEN7cutlass10bfloat16_tEfNS9_4arch4Sm90ELb1ELb0ELb0ELb1ELb0ELb0ELi1ELi2ELi2EE3mmaINS_12FlashAttnBwdISD_NS_21CollectiveEpilogueBwdIS8_SA_Li256ELb0EEENS_22SingleTileSchedulerBwdEE13SharedStorageENS1_6TensorINS1_11ArrayEngineIfLi32EEENS1_6LayoutINS2_IJNS2_IJNS3_ILi2EEESP_NS3_ILi8EEEEEES4_S4_EEENS2_IJNS2_IJS4_SP_NS3_ILi4EEEEEENS3_ILi0EEESV_EEEEEEEEEvRKNSD_6ParamsENS9_16PipelineTmaAsyncILi2EEES13_RNS9_13PipelineStateILj2EEERT0_S18_iiNS2_IJiiiEEERT_ENKUlvE_clEv'
      ptxas info    : (C7505) Potential Performance Loss: 'setmaxnreg' ignored to allow debugging.
      ptxas info    : (C7509) Potential Performance Loss: wgmma.mma_async instructions are serialized due to the presence of Extern calls in the function '_ZN7cutlass13device_kernelIN5flash12FlashAttnBwdINS1_21CollectiveMainloopBwdILi2EN4cute5tupleIJNS4_1CILi1EEES7_S7_EEENS5_IJNS6_ILi128EEES9_NS6_ILi64EEEEEENS_10bfloat16_tEfNS_4arch4Sm90ELb1ELb0ELb0ELb0ELb0ELb0ELi1ELi2ELi2EEENS1_21CollectiveEpilogueBwdISB_SC_Li256ELb0EEENS1_22SingleTileSchedulerBwdEEEEEvNT_6ParamsE'
      ptxas info    : (C7511) Potential Performance Loss: wgmma.mma_async instructions are serialized due to insufficient register resources for the wgmma pipeline in the function '_ZN4cute28SM90_64x64x16_F32BF16BF16_RSILNS_4GMMA5MajorE0ELS2_1ELNS1_7ScaleInE1ELS3_1EE3fmaERKjS6_S6_S6_RKmRfS9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_S9_NS1_8ScaleOutE'
      ptxas info    : (C7511) Potential Performance Loss: wgmma.mma_async instructions are serialized due to insufficient register resources for the wgmma pipeline in the function '_ZN4cute29SM90_64x128x16_F32BF16BF16_SSILNS_4GMMA5MajorE0ELS2_0ELNS1_7ScaleInE1ELS3_1EE3fmaERKmS6_RfS7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_NS1_8ScaleOutE'
      ptxas info    : (C7511) Potential Performance Loss: wgmma.mma_async instructions are serialized due to insufficient register resources for the wgmma pipeline in the function '_ZN4cute28SM90_64x64x16_F32BF16BF16_SSILNS_4GMMA5MajorE0ELS2_1ELNS1_7ScaleInE1ELS3_1EE3fmaERKmS6_RfS7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_S7_NS1_8ScaleOutE'
      double free or corruption (out)
      Aborted (core dumped)

Above is the log printed when the compilation failed after a few hours (I don't know exactly how long, as I was AFK).
Compilation consumed all of the RAM (2 TiB) before I pruned the templates, but with the workaround mentioned above it stays under 100 GiB, so I don't think it's an OOM issue.

Dev-Jahn changed the title Is there any way to compile the codes with nvcc debug flag(-G)? → Is there any way to compile the codes with nvcc debug flag(-G) faster? Dec 2, 2024
Dev-Jahn changed the title Is there any way to compile the codes with nvcc debug flag(-G) faster? → Is there any way to compile the codes with nvcc debug flag(-G)? Dec 3, 2024

tridao commented Dec 3, 2024

You can try the decode branch. We recently added some env vars to enable/disable features.
E.g. you can set these to disable most of the features:

FLASH_ATTENTION_DISABLE_BACKWARD=TRUE
FLASH_ATTENTION_DISABLE_SPLIT=TRUE
FLASH_ATTENTION_DISABLE_LOCAL=TRUE
FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE
FLASH_ATTENTION_DISABLE_FP16=TRUE
FLASH_ATTENTION_DISABLE_FP8=TRUE
FLASH_ATTENTION_DISABLE_APPENDKV=TRUE
FLASH_ATTENTION_DISABLE_VARLEN=TRUE
FLASH_ATTENTION_DISABLE_CLUSTER=TRUE
FLASH_ATTENTION_DISABLE_PACKGQA=TRUE
FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE
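
For example, to build with most features disabled, these can be set as ordinary environment variables on the build command line (a sketch assuming the usual setup.py build shown in the logs above; the exact set to keep depends on what you're debugging):

FLASH_ATTENTION_DISABLE_BACKWARD=TRUE FLASH_ATTENTION_DISABLE_SPLIT=TRUE FLASH_ATTENTION_DISABLE_LOCAL=TRUE python setup.py install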


tridao commented Dec 3, 2024

Btw I always just print things out to debug
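
A minimal sketch of that approach (not code from this repo; the kernel name and arguments are hypothetical): a device-side printf guarded to a single thread, which avoids the -G compile-time cost entirely.

#include <cstdio>

__global__ void my_debug_kernel(const float *x, int n) {
    // Guard so that only one thread prints, keeping the output readable.
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        printf("n = %d, x[0] = %f\n", n, n > 0 ? x[0] : 0.0f);
    }
    // ... rest of the kernel ...
}

// Device-side printf output is flushed when the kernel completes,
// e.g. after cudaDeviceSynchronize().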


Dev-Jahn commented Dec 4, 2024

Btw I always just print things out to debug

Thanks for the answer.
Maybe I was unnecessarily struggling with an overengineered tool. 😂

miaomiaoma0703 commented

I'm trying to compile the debug version of FlashAttention 2 (on A100, CUDA 12.4, PyTorch 2.4), but the compilation fails with OOM. The machine has 1 TiB RAM and 64 cores.
setup.py:

extra_compile_args={
    "cxx": ["-O0", "-g", "-std=c++17"] + generator_flag,
    "nvcc": append_nvcc_threads(
        [
            "-O0",
            "-G",
            "-g",
            "-std=c++17",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_HALF2_OPERATORS__",
            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
            "--use_fast_math",
            # "--ptxas-options=-v",
            # "--ptxas-options=-O2",
            # "-lineinfo",
            # "-DFLASHATTENTION_DISABLE_BACKWARD",
            # "-DFLASHATTENTION_DISABLE_DROPOUT",
            # "-DFLASHATTENTION_DISABLE_ALIBI",
            # "-DFLASHATTENTION_DISABLE_SOFTCAP",
            # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
            # "-DFLASHATTENTION_DISABLE_LOCAL",
        ]
        + generator_flag
        + cc_flag
    ),
},

I tried setting max_jobs=64/16/8/4; all of those runs ran out of memory and were killed. With max_jobs=1, compilation failed with an error (after more than 3 hours):

    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v', '-j', '1']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/miaomiao.ma/LLM/flash-attention/flash-attention-debug/setup.py", line 512, in <module>
    setup(
  File "/usr/lib/python3/dist-packages/setuptools/__init__.py", line 153, in setup
    return distutils.core.setup(**attrs)
  File "/usr/lib/python3.10/distutils/core.py", line 148, in setup
    dist.run_commands()
  File "/usr/lib/python3.10/distutils/dist.py", line 966, in run_commands
    self.run_command(cmd)
  File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
    cmd_obj.run()
  File "/usr/lib/python3/dist-packages/setuptools/command/install.py", line 74, in run
    self.do_egg_install()
  File "/usr/lib/python3/dist-packages/setuptools/command/install.py", line 116, in do_egg_install
    self.run_command('bdist_egg')
  File "/usr/lib/python3.10/distutils/cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
    cmd_obj.run()
  File "/usr/lib/python3/dist-packages/setuptools/command/bdist_egg.py", line 164, in run
    cmd = self.call_command('install_lib', warn_dir=0)
  File "/usr/lib/python3/dist-packages/setuptools/command/bdist_egg.py", line 150, in call_command
    self.run_command(cmdname)
  File "/usr/lib/python3.10/distutils/cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
    cmd_obj.run()
  File "/usr/lib/python3/dist-packages/setuptools/command/install_lib.py", line 23, in run
    self.build()
  File "/usr/lib/python3.10/distutils/command/install_lib.py", line 109, in build
    self.run_command('build_ext')
  File "/usr/lib/python3.10/distutils/cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "/usr/lib/python3.10/distutils/dist.py", line 985, in run_command
    cmd_obj.run()
  File "/usr/lib/python3/dist-packages/setuptools/command/build_ext.py", line 79, in run
    _build_ext.run(self)
  File "/usr/lib/python3.10/distutils/command/build_ext.py", line 340, in run
    self.build_extensions()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 866, in build_extensions
    build_ext.build_extensions(self)
  File "/usr/lib/python3.10/distutils/command/build_ext.py", line 449, in build_extensions
    self._build_extensions_serial()
  File "/usr/lib/python3.10/distutils/command/build_ext.py", line 474, in _build_extensions_serial
    self.build_extension(ext)
  File "/usr/lib/python3/dist-packages/setuptools/command/build_ext.py", line 202, in build_extension
    _build_ext.build_extension(self, ext)
  File "/usr/lib/python3.10/distutils/command/build_ext.py", line 529, in build_extension
    objects = self.compiler.compile(sources,
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 679, in unix_wrap_ninja_compile
    _write_ninja_file_and_compile_objects(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1785, in _write_ninja_file_and_compile_objects
    _run_ninja_build(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 2121, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error compiling objects for extension

With max_jobs=2 the memory is still not enough, and the process takes a long time.
I'm new to compiling C++ and CUDA. How can I successfully compile the debug version?


Dev-Jahn commented Dec 4, 2024

@miaomiaoma0703 After a few hours of workarounds, I've managed to compile the code without errors.
I've observed that a single ptx intermediate file takes a few GBs, so a single source being too large may be the reason for the failure.
Do not try to increase nvcc --threads or MAX_JOBS; they do nothing for the compilation of a single file by cicc and ptxas.
Just explicitly instantiate EVERY template parameter needed for your debug case and remove all the others (as I did in my first comment).
Your case may differ from my FA-3 one, but it took me about 18 minutes to compile a single .cu file for the forward pass.
I'm not sure it was worth it, though. Just using printf would be the easy way, as Prof. Dao mentioned.

miaomiaoma0703 commented


@Dev-Jahn Thank you very much! Perhaps a single ptx intermediate file takes a few GBs and 1 TiB of RAM is too small to compile successfully. I compiled the release version successfully in about 15 minutes, and for the debug version I plan to either print things out to debug, like the author, or, like you, explicitly instantiate EVERY template parameter only for my debug case and remove all the others (as you did in your first comment).
