diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index b747c8d60..f081a99f1 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -8,7 +8,7 @@ #include #include #include // For at::Generator and at::PhiloxCudaState -#include // For at::cuda::philox::unpack +#include "philox_unpack.cuh" // For at::cuda::philox::unpack #include diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 30217f3f6..a74de974a 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -4,7 +4,7 @@ #pragma once -#include // For at::cuda::philox::unpack +#include "philox_unpack.cuh" // For at::cuda::philox::unpack #include diff --git a/csrc/flash_attn/src/philox_unpack.cuh b/csrc/flash_attn/src/philox_unpack.cuh new file mode 100644 index 000000000..3a54f45cb --- /dev/null +++ b/csrc/flash_attn/src/philox_unpack.cuh @@ -0,0 +1,4 @@ +// This is purely so that it works with torch 2.1. For torch 2.2+ we can include ATen/cuda/PhiloxUtils.cuh + +#pragma once +#include diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index e3c0e8e04..576b24896 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.1.post3" +__version__ = "2.7.1.post4" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/setup.py b/setup.py index c2126134a..d7d16b642 100644 --- a/setup.py +++ b/setup.py @@ -114,7 +114,7 @@ def check_if_rocm_home_none(global_option: str) -> None: def append_nvcc_threads(nvcc_extra_args): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" + nvcc_threads = os.getenv("NVCC_THREADS") or "2" return nvcc_extra_args + ["--threads", nvcc_threads]