Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix build when flash attention and memory efficient attention are disabled #18761

Merged
merged 11 commits into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)

# onnxruntime_USE_CUTLASS gates every cutlass-dependent kernel in this build
# (MoE / ShardedMoE via USE_CUTLASS, and indirectly the attention kernels
# below); it is only meaningful for CUDA builds, hence the dependency.
cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF)
# Defaults OFF on Windows (NOT WIN32 in the dependency list) and when CUDA is
# not enabled.
cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

Expand Down Expand Up @@ -693,16 +694,20 @@ if (onnxruntime_USE_CUDA)
enable_language(CUDA)
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")

# Attention kernels are contrib ops, so they cannot be built when contrib ops
# are disabled.
# NOTE(review): this span appears to be the pre-change (removed) side of the
# rendered diff; the final file handles onnxruntime_DISABLE_CONTRIB_OPS in a
# standalone guard after the CUDA branch instead — confirm against the merged
# tree.
if (onnxruntime_DISABLE_CONTRIB_OPS)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
# cutlass (and the kernels built on it) requires CUDA >= 11.6.
# NOTE(review): this rendered span interleaves pre-change lines (the flash
# attention message and the two attention set() calls) with post-change lines
# (the cutlass message and set); the merged file keeps only the cutlass
# handling here, relying on the later USE_CUTLASS guard to turn off the
# attention kernels — verify against the merged tree before reusing verbatim.
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
message( STATUS "Turn off cutlass since CUDA compiler version < 11.6")
set(onnxruntime_USE_CUTLASS OFF)
endif()
else()
set(onnxruntime_USE_CUTLASS OFF)
endif()

# Flash attention and memory efficient attention are contrib ops built on top
# of cutlass, so both are forced OFF whenever cutlass is unavailable or
# contrib ops are disabled. The message distinguishes which precondition
# failed; contrib-ops-disabled takes precedence in the reporting.
if (onnxruntime_DISABLE_CONTRIB_OPS OR NOT onnxruntime_USE_CUTLASS)
  set(onnxruntime_USE_FLASH_ATTENTION OFF)
  set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
  if (onnxruntime_DISABLE_CONTRIB_OPS)
    message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled")
  else()
    message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled")
  endif()
endif()
Expand Down Expand Up @@ -887,6 +892,11 @@ function(onnxruntime_set_compile_flags target_name)
if (onnxruntime_ENABLE_ATEN)
target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
endif()

if (onnxruntime_USE_CUTLASS)
target_compile_definitions(${target_name} PRIVATE USE_CUTLASS)
endif()

set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON)
if (onnxruntime_USE_CUDA)
# Suppress a "conversion_function_not_usable" warning in gsl/span
Expand Down
2 changes: 1 addition & 1 deletion cmake/external/cutlass.cmake
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
if (onnxruntime_USE_CUTLASS)
include(FetchContent)
FetchContent_Declare(
cutlass
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
Expand Down Expand Up @@ -202,3 +204,5 @@ Status ShardedMoE<T>::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#pragma once

#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
Expand Down Expand Up @@ -34,3 +36,5 @@ class ShardedMoE final : public NcclKernel, public MoEBase {
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

#endif
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
#ifdef USE_CUTLASS
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE);
#endif
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
Expand Down Expand Up @@ -165,8 +167,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);

#ifdef USE_CUTLASS
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE);
#endif

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul);
Expand Down Expand Up @@ -266,8 +270,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
#ifdef USE_CUTLASS
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention)>,
Expand Down Expand Up @@ -367,8 +373,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll)>,

#ifdef USE_CUTLASS
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE)>,
#endif

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul)>,
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef USE_CUTLASS

#pragma once

#include <cuda_runtime_api.h>
Expand Down Expand Up @@ -49,3 +52,5 @@ inline int compute_occupancy_for_kernel() {
}

} // namespace ort_fastertransformer

#endif
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USE_CUTLASS

#include "cutlass_heuristic.h"

Expand Down Expand Up @@ -185,3 +186,5 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
}

} // namespace ort_fastertransformer

#endif
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USE_CUTLASS

#pragma once

Expand All @@ -37,3 +38,4 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
const int multi_processor_count, const int is_weight_only);

} // namespace ort_fastertransformer
#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
*
*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/array.h"
Expand Down Expand Up @@ -131,3 +133,5 @@ struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, Epilog
};

} // namespace ort_fastertransformer

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

#pragma once

namespace ort_fastertransformer {
Expand Down Expand Up @@ -56,3 +58,5 @@ struct CutlassGemmConfig {
};

} // namespace ort_fastertransformer

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
*
**************************************************************************************************/

#ifdef USE_CUTLASS

/*! \file
\brief Scheduler for grouped GEMM
*/
Expand Down Expand Up @@ -77,3 +79,5 @@ struct GemmMoeProblemVisitor
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/layout/matrix.h"
Expand Down Expand Up @@ -150,4 +152,6 @@ struct MixedGemmArchTraits<

} // namespace kernel
} // namespace gemm
} // namespace cutlass
} // namespace cutlass

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
*
**************************************************************************************************/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/complex.h"
Expand Down Expand Up @@ -461,3 +463,5 @@ struct MoeFCGemm {
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

#pragma once

#include <cuda_runtime_api.h>
Expand Down Expand Up @@ -62,3 +64,5 @@ class MoeGemmRunner {
};

} // namespace ort_fastertransformer

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
* limitations under the License.
*/

// Explicit instantiation of the half-precision MoE GEMM runner.
// The whole translation unit is guarded by USE_CUTLASS so that builds
// without cutlass support skip the MoE GEMM kernels entirely (this PR's
// purpose: fix the build when the cutlass-based attention/MoE paths are
// disabled).
#ifdef USE_CUTLASS

#include "moe_gemm_kernels_template.h"

namespace ort_fastertransformer {
// Instantiate for (activation=half, weight=half).
template class MoeGemmRunner<half, half>;
} // namespace ort_fastertransformer

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
* limitations under the License.
*/

// Explicit instantiation of the single-precision MoE GEMM runner.
// The whole translation unit is guarded by USE_CUTLASS so that builds
// without cutlass support skip the MoE GEMM kernels entirely.
#ifdef USE_CUTLASS

#include "moe_gemm_kernels_template.h"

namespace ort_fastertransformer {
// Instantiate for (activation=float, weight=float).
template class MoeGemmRunner<float, float>;
} // namespace ort_fastertransformer

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#ifdef USE_CUTLASS

// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__
#pragma GCC diagnostic push
Expand Down Expand Up @@ -426,3 +428,5 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A, const WeightType* B, con
}

} // namespace ort_fastertransformer

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#include <cuda.h>
#include <cuda_fp16.h>
#include <math.h>
Expand Down Expand Up @@ -898,3 +900,5 @@ template void finalize_moe_routing_kernelLauncher(const half*, half*, const half
cudaStream_t);

} // namespace ort_fastertransformer

#endif
6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#pragma once

#include "moe_gemm_kernels.h"
Expand Down Expand Up @@ -172,4 +174,6 @@ class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_s
}
};

} // namespace ort_fastertransformer
} // namespace ort_fastertransformer

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
\brief Base scheduler for grouped problems, using MoE
*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
Expand Down Expand Up @@ -288,3 +290,5 @@ struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode:
} // namespace kernel
} // namespace gemm
} // namespace cutlass

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
/*! \file
\brief Defines new layouts needed for MoE
*/

#ifdef USE_CUTLASS

#pragma once

#include "cutlass/cutlass.h"
Expand Down Expand Up @@ -59,3 +62,5 @@ struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>> {

} // namespace layout
} // namespace cutlass

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/moe.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "moe.h"
Expand Down Expand Up @@ -117,3 +119,5 @@ Status MoE<T>::ComputeInternal(OpKernelContext* context) const {
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

#endif
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/moe.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUTLASS

#pragma once

#include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
Expand All @@ -24,3 +26,5 @@ class MoE final : public CudaKernel, public MoEBase {
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

#endif
Loading
Loading