From de46ba5807b8c779e911de4987396c7e151f3273 Mon Sep 17 00:00:00 2001
From: "Peng Wang(AI FWK)"
Date: Fri, 8 Dec 2023 10:29:00 +0000
Subject: [PATCH 1/9] fix

---
 onnxruntime/contrib_ops/cuda/math/gemm_float8.cu | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
index 56b541f5256bf..a22e1dbae3f4f 100644
--- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
+++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
@@ -251,15 +251,21 @@ Status GemmFloat8::ComputeGemm(
   CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
       operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb,
       sizeof(ctransb)));
+#if CUDA_VERSION >= 11060
+  // CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET exists from https://docs.nvidia.com/cuda/archive/11.6.0/pdf/CUBLAS_Library.pdf
   if (sm_count_ != 0) {
     int math_sm_count = static_cast<int>(sm_count_);
     CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count,
         sizeof(math_sm_count)));
   }
+#endif
 
   if (has_scales) {
     // gemm float 8
+#if CUDA_VERSION >= 11080
+    // CUBLASLT_MATMUL_DESC_FAST_ACCUM, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
+    // CUBLASLT_MATMUL_DESC_D_SCALE_POINTER exist from https://docs.nvidia.com/cuda/archive/11.8.0/pdf/CUBLAS_Library.pdf
     const int8_t ifast_accumulation_mode = 1;
     CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
         operationDesc,
@@ -274,6 +280,7 @@ Status GemmFloat8::ComputeGemm(
     CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y,
         sizeof(p_scale_b)));
+#endif
 
   // float 8
 #if !defined(DISABLE_FLOAT8_TYPES)

From a200264eee4838984d49aade5f62a8e42d45e3ee Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Fri, 8 Dec 2023 10:31:36 +0000
Subject: [PATCH 2/9] format

---
 onnxruntime/contrib_ops/cuda/math/gemm_float8.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
index a22e1dbae3f4f..064b6dd392437 100644
--- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
+++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
@@ -280,7 +280,7 @@ Status GemmFloat8::ComputeGemm(
     CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y,
         sizeof(p_scale_b)));
-#endif
+#endif
 
   // float 8
 #if !defined(DISABLE_FLOAT8_TYPES)
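Background on the guard the two patches above introduce: cuda.h defines CUDA_VERSION as major * 1000 + minor * 10, so 11060 and 11080 are CUDA 11.6 and 11.8, the releases whose cuBLASLt headers first declare these attributes. A compilable sketch of the same pattern in isolation — the helper function and its name are invented for illustration, while the cuBLASLt types, attribute, and call are the real API:

    #include <cublasLt.h>
    #include <cuda.h>  // defines CUDA_VERSION = major * 1000 + minor * 10

    // Hypothetical helper: request a target SM count only when the toolkit
    // we compile against (CUDA >= 11.6) actually declares the attribute.
    cublasStatus_t SetSmCountIfSupported(cublasLtMatmulDesc_t desc, int sm_count) {
    #if CUDA_VERSION >= 11060
      return cublasLtMatmulDescSetAttribute(
          desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &sm_count, sizeof(sm_count));
    #else
      (void)desc;
      (void)sm_count;
      return CUBLAS_STATUS_SUCCESS;  // older toolkit: silently skip the hint
    #endif
    }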
From 3d78b1756608b0469c321275d1c4e6ec14ee2c51 Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Fri, 8 Dec 2023 12:51:55 +0000
Subject: [PATCH 3/9] fixes

---
 cmake/CMakeLists.txt                                      |  8 ++++++++
 cmake/external/cutlass.cmake                              |  2 +-
 onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc    |  4 ++++
 onnxruntime/contrib_ops/cuda/collective/sharded_moe.h     |  4 ++++
 onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc      |  8 ++++++++
 .../contrib_ops/cuda/moe/ft_moe/compute_occupancy.h       |  5 +++++
 .../contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc      |  3 +++
 .../contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.h       |  2 ++
 .../contrib_ops/cuda/moe/ft_moe/epilogue_helpers.h        |  4 ++++
 onnxruntime/contrib_ops/cuda/moe/ft_moe/ft_gemm_configs.h |  4 ++++
 .../cuda/moe/ft_moe/gemm_moe_problem_visitor.h            |  4 ++++
 .../contrib_ops/cuda/moe/ft_moe/layout_traits_helper.h    |  6 +++++-
 .../contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h      |  4 ++++
 .../contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h        |  4 ++++
 .../cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu         |  4 ++++
 .../cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu         |  4 ++++
 .../cuda/moe/ft_moe/moe_gemm_kernels_template.h           |  4 ++++
 onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu     |  4 ++++
 onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h      |  6 +++++-
 .../contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h     |  4 ++++
 .../contrib_ops/cuda/moe/ft_moe/tile_interleaved_layout.h |  5 +++++
 onnxruntime/contrib_ops/cuda/moe/moe.cc                   |  4 ++++
 onnxruntime/contrib_ops/cuda/moe/moe.h                    |  4 ++++
 onnxruntime/contrib_ops/cuda/moe/moe_base.h               |  4 ++++
 onnxruntime/test/contrib_ops/moe_test.cc                  |  4 ++++
 25 files changed, 106 insertions(+), 3 deletions(-)

diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 2331562d4a3bd..cf86550e03a48 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -97,6 +97,7 @@ option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
 cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
 option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
+option(onnxruntime_USE_CUTLASS "Build with Cutlass support" OFF)
 
 option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
 option(onnxruntime_USE_AVX "Use AVX instructions" OFF)
@@ -713,11 +714,13 @@ if (onnxruntime_USE_CUDA)
 
   if (onnxruntime_USE_FLASH_ATTENTION)
     message( STATUS "Enable flash attention for CUDA EP")
+    set(onnxruntime_USE_CUTLASS ON) # Flash attention requires cutlass
     list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1)
     list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1)
   endif()
   if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
     message( STATUS "Enable memory efficient attention for CUDA EP")
+    set(onnxruntime_USE_CUTLASS ON) # Memory efficient attention requires cutlass
    list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
     list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1)
   endif()
@@ -886,6 +889,11 @@ function(onnxruntime_set_compile_flags target_name)
   if (onnxruntime_ENABLE_ATEN)
     target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
   endif()
+
+  if (onnxruntime_USE_CUTLASS)
+    target_compile_definitions(${target_name} PRIVATE USE_CUTLASS)
+  endif()
+
   set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON)
   if (onnxruntime_USE_CUDA)
     # Suppress a "conversion_function_not_usable" warning in gsl/span
diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake
index 983eecdd88235..efc708bd681c0 100644
--- a/cmake/external/cutlass.cmake
+++ b/cmake/external/cutlass.cmake
@@ -1,4 +1,4 @@
-if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
+if (onnxruntime_USE_CUTLASS)
   include(FetchContent)
   FetchContent_Declare(
     cutlass
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
index 40a667ffd5d83..9b989dac9a94b 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#ifdef USE_CUTLASS
+
 #include "core/common/safeint.h"
 #include "core/providers/cuda/cuda_common.h"
 #include "contrib_ops/cuda/bert/transformer_cuda_common.h"
@@ -202,3 +204,5 @@ Status ShardedMoE::SynchronizeExpertsStartIndex(AllocatorPtr& allocator,
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
+
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
index 5ea4ae59c4020..cbd483fddab78 100644
--- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.h
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#ifdef USE_CUTLASS
+
 #pragma once
 
 #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
@@ -34,3 +36,5 @@ class ShardedMoE final : public NcclKernel, public MoEBase {
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
+
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 7875ac75b8188..be7e9f6a8225e 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -70,8 +70,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Crop);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Crop);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
+#ifdef USE_CUTLASS
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MoE);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MoE);
+#endif
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention);
@@ -165,8 +167,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllR
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll);
 
+#ifdef USE_CUTLASS
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ShardedMoE);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ShardedMoE);
+#endif
 
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul);
@@ -266,8 +270,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+#ifdef USE_CUTLASS
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+#endif
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
@@ -367,8 +373,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
 
+#ifdef USE_CUTLASS
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+#endif
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
index 86136ea244e23..9b97690fe70fd 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/compute_occupancy.h
@@ -13,6 +13,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
+#ifdef USE_CUTLASS
+
 #pragma once
 
 #include
@@ -49,3 +52,5 @@ inline int compute_occupancy_for_kernel() {
 }
 
 }  // namespace ort_fastertransformer
+
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
index 5d4c6793ec995..f0abd46572a90 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/cutlass_heuristic.cc
@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#ifdef USE_CUTLASS
 
 #include "cutlass_heuristic.h"
 
@@ -185,3 +186,5 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector
@@ -62,3 +64,5 @@ class MoeGemmRunner {
 };
 
 }  // namespace ort_fastertransformer
+
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu
index 1d9a249db4237..1d0dfe7c5a647 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_fp16.cu
@@ -14,8 +14,12 @@
  * limitations under the License.
  */
 
+#ifdef USE_CUTLASS
+
 #include "moe_gemm_kernels_template.h"
 
 namespace ort_fastertransformer {
 template class MoeGemmRunner<half, half>;
 }  // namespace ort_fastertransformer
+
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu
index 7b250e6ca9060..7a5d97902ee8f 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp32_fp32.cu
@@ -14,8 +14,12 @@
  * limitations under the License.
  */
 
+#ifdef USE_CUTLASS
+
 #include "moe_gemm_kernels_template.h"
 
 namespace ort_fastertransformer {
 template class MoeGemmRunner<float, float>;
 }  // namespace ort_fastertransformer
+
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
index 66950c9b65970..3fd0fc47055a5 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#ifdef USE_CUTLASS
+
 // Ignore CUTLASS warnings about type punning
 #ifdef __GNUC__
 #pragma GCC diagnostic push
@@ -426,3 +428,5 @@ void MoeGemmRunner<T, WeightType>::moe_gemm(const T* A, const WeightType* B, con
 }
 
 }  // namespace ort_fastertransformer
+
+#endif
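The two .cu stubs above rely on explicit template instantiation: the full MoeGemmRunner template lives in moe_gemm_kernels_template.h, and each stub forces one concrete instantiation that other translation units link against. Wrapping a stub in USE_CUTLASS therefore removes the instantiation along with the template. A self-contained sketch of the mechanism, with simplified names that are not the real MoeGemmRunner interface:

    template <typename T, typename WeightT>
    struct DemoGemmRunner {
      // stand-in for the real moe_gemm dispatch
      void Gemm(const T* /*A*/, const WeightT* /*B*/, T* /*C*/, int /*m*/) {}
    };

    // Explicit instantiations: these emit the template's code into this
    // object file, so callers elsewhere link against it without having to
    // re-instantiate (mirroring the fp16_fp16 / fp32_fp32 stubs above).
    template struct DemoGemmRunner<float, float>;
    template struct DemoGemmRunner<double, double>;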
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
index f4f2b49032d23..9232e8d012933 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu
@@ -16,6 +16,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#ifdef USE_CUTLASS
+
 #include
 #include
 #include
@@ -898,3 +900,5 @@ template void finalize_moe_routing_kernelLauncher(const half*, half*, const half
                                                   cudaStream_t);
 
 }  // namespace ort_fastertransformer
+
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
index 5cc2a3f79f003..f09471de1cc2e 100644
--- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
+++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h
@@ -16,6 +16,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#ifdef USE_CUTLASS
+
 #pragma once
 
 #include "moe_gemm_kernels.h"
@@ -172,4 +174,6 @@ class CutlassMoeFCRunner> {
 
 }  // namespace layout
 }  // namespace cutlass
+
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc
index 3f26a274109ad..0da06192e266b 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#ifdef USE_CUTLASS
+
 #include "core/common/safeint.h"
 #include "core/providers/cuda/cuda_common.h"
 #include "moe.h"
@@ -117,3 +119,5 @@ Status MoE::ComputeInternal(OpKernelContext* context) const {
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
+
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.h b/onnxruntime/contrib_ops/cuda/moe/moe.h
index c4d8c4dc64c57..710b914f0633d 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.h
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.h
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#ifdef USE_CUTLASS
+
 #pragma once
 
 #include "contrib_ops/cuda/moe/ft_moe/moe_kernel.h"
@@ -24,3 +26,5 @@ class MoE final : public CudaKernel, public MoEBase {
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
+
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h
index f55a7cde2e208..dc8b9d57f79f6 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h
+++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#ifdef USE_CUTLASS
+
 #pragma once
 
 #include "core/common/common.h"
@@ -170,3 +172,5 @@ class MoEBase {
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
+
+#endif
diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc
index ebb0261deefa5..844cc877f2568 100644
--- a/onnxruntime/test/contrib_ops/moe_test.cc
+++ b/onnxruntime/test/contrib_ops/moe_test.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#ifdef USE_CUTLASS
+
 #include "gtest/gtest.h"
 #include "test/common/tensor_op_test_utils.h"
 #include "test/common/cuda_op_test_utils.h"
@@ -421,3 +423,5 @@ TEST(MoETest, MoETest_Relu) {
 }  // namespace test
 }  // namespace onnxruntime
+
+#endif
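Taken together, PATCH 3 wires one flag through the whole stack: the CMake option onnxruntime_USE_CUTLASS becomes a -DUSE_CUTLASS compile definition via target_compile_definitions, and every cutlass-dependent source, header, kernel registration, and test above is wrapped in that define. A minimal sketch of the consumer side (the file and function names here are invented for illustration, not part of the patch):

    // moe_feature.cc (hypothetical): always handed to the compiler, but it
    // only contributes code when CMake passed -DUSE_CUTLASS for the target.
    #ifdef USE_CUTLASS
    #include <cstdio>
    void RunCutlassBackedMoE() { std::printf("cutlass path compiled in\n"); }
    #else
    // Without the define the translation unit is intentionally empty, so a
    // cutlass-free build neither compiles nor links anything from cutlass.
    #endif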
From 315938bce57f140635a03157b82f6f275be3c9d4 Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Fri, 8 Dec 2023 13:28:14 +0000
Subject: [PATCH 4/9] fix

---
 include/onnxruntime/core/framework/float16.h | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/include/onnxruntime/core/framework/float16.h b/include/onnxruntime/core/framework/float16.h
index 2d289d6febb8d..6e1f99bbccf30 100644
--- a/include/onnxruntime/core/framework/float16.h
+++ b/include/onnxruntime/core/framework/float16.h
@@ -89,11 +89,7 @@ struct MLFloat16 : onnxruntime_float16::Float16Impl<MLFloat16> {
 struct BFloat16 : onnxruntime_float16::BFloat16Impl<BFloat16> {
   using Base = onnxruntime_float16::BFloat16Impl<BFloat16>;
 
-#if defined(__HIP__)
   ORT_HOST_DEVICE BFloat16() = default;
-#else
-  BFloat16() = default;
-#endif
 
   struct FromBitsT {};
   static constexpr ORT_HOST_DEVICE FromBitsT FromBits() noexcept { return FromBitsT(); }

From ca05a1ea9f93595296386da7c1a51e132a0dbd23 Mon Sep 17 00:00:00 2001
From: "Peng Wang(AI FWK)"
Date: Fri, 8 Dec 2023 13:42:35 +0000
Subject: [PATCH 5/9] Revert "fix"

This reverts commit 315938bce57f140635a03157b82f6f275be3c9d4.

---
 include/onnxruntime/core/framework/float16.h | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/include/onnxruntime/core/framework/float16.h b/include/onnxruntime/core/framework/float16.h
index 6e1f99bbccf30..2d289d6febb8d 100644
--- a/include/onnxruntime/core/framework/float16.h
+++ b/include/onnxruntime/core/framework/float16.h
@@ -89,7 +89,11 @@ struct MLFloat16 : onnxruntime_float16::Float16Impl<MLFloat16> {
 struct BFloat16 : onnxruntime_float16::BFloat16Impl<BFloat16> {
   using Base = onnxruntime_float16::BFloat16Impl<BFloat16>;
 
+#if defined(__HIP__)
   ORT_HOST_DEVICE BFloat16() = default;
+#else
+  BFloat16() = default;
+#endif
 
   struct FromBitsT {};
   static constexpr ORT_HOST_DEVICE FromBitsT FromBits() noexcept { return FromBitsT(); }
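PATCH 4 and PATCH 5 are a change and its immediate revert: making the ORT_HOST_DEVICE annotation on BFloat16's defaulted constructor unconditional evidently did not hold up across all compilers, so the __HIP__ special case stays. The shape of the pattern, sketched with an assumed stand-in for ORT_HOST_DEVICE (onnxruntime defines the real macro elsewhere):

    // Assumed stand-in for onnxruntime's ORT_HOST_DEVICE macro.
    #if defined(__CUDACC__) || defined(__HIPCC__)
    #define ORT_HOST_DEVICE __host__ __device__
    #else
    #define ORT_HOST_DEVICE
    #endif

    struct Bf16Demo {  // illustrative type, not onnxruntime's BFloat16
    #if defined(__HIP__)
      // the HIP build annotates the defaulted ctor with __host__ __device__
      ORT_HOST_DEVICE Bf16Demo() = default;
    #else
      // other builds keep a plain defaulted ctor
      Bf16Demo() = default;
    #endif
      unsigned short val{0};
    };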
From cf989b166d309cef515fce7223ac5cf1c64fa7a7 Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Fri, 8 Dec 2023 13:54:01 +0000
Subject: [PATCH 6/9] fix

---
 cmake/CMakeLists.txt | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index cf86550e03a48..2e3bedaf6d901 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -95,9 +95,9 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
 option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
 option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
 
+cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with Cutlass support" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
 cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
 option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
-option(onnxruntime_USE_CUTLASS "Build with Cutlass support" OFF)
 
 option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
 option(onnxruntime_USE_AVX "Use AVX instructions" OFF)
@@ -693,16 +693,20 @@ if (onnxruntime_USE_CUDA)
   enable_language(CUDA)
   message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")
 
-  if (onnxruntime_DISABLE_CONTRIB_OPS)
-    set(onnxruntime_USE_FLASH_ATTENTION OFF)
-    set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
-  endif()
   if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
-    message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
-    set(onnxruntime_USE_FLASH_ATTENTION OFF)
-    set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
+    message( STATUS "Turn off cutlass since CUDA compiler version < 11.6")
+    set(onnxruntime_USE_CUTLASS OFF)
   endif()
 else()
+  set(onnxruntime_USE_CUTLASS OFF)
+endif()
+
+if (not onnxruntime_USE_CUTLASS or onnxruntime_DISABLE_CONTRIB_OPS)
+  if (onnxruntime_DISABLE_CONTRIB_OPS)
+    message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled")
+  else()
+    message( STATUS "Turn off flash attention/memory efficient attention since cutlass is not enabled")
+  endif()
   set(onnxruntime_USE_FLASH_ATTENTION OFF)
   set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
 endif()
@@ -714,13 +718,11 @@ if (onnxruntime_USE_CUDA)
 
   if (onnxruntime_USE_FLASH_ATTENTION)
     message( STATUS "Enable flash attention for CUDA EP")
-    set(onnxruntime_USE_CUTLASS ON) # Flash attention requires cutlass
     list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1)
     list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1)
   endif()
   if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
     message( STATUS "Enable memory efficient attention for CUDA EP")
-    set(onnxruntime_USE_CUTLASS ON) # Memory efficient attention requires cutlass
     list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
     list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1)
   endif()

From e94b18c8a412ee745140c8f29224a382a49a5828 Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Fri, 8 Dec 2023 13:57:05 +0000
Subject: [PATCH 7/9] fix

---
 cmake/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 2e3bedaf6d901..d6620160cbbad 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -701,7 +701,7 @@ else()
   set(onnxruntime_USE_CUTLASS OFF)
 endif()
 
-if (not onnxruntime_USE_CUTLASS or onnxruntime_DISABLE_CONTRIB_OPS)
+if (NOT onnxruntime_USE_CUTLASS OR onnxruntime_DISABLE_CONTRIB_OPS)
   if (onnxruntime_DISABLE_CONTRIB_OPS)
     message( STATUS "Turn off flash attention/memory efficient attention since contrib ops are disabled")
   else()
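A note on PATCH 7: CMake's if() operators are case-sensitive keywords, so the lowercase `not ... or ...` introduced in PATCH 6 is not parsed as a boolean expression at all (CMake reports the tokens as unknown arguments). Uppercasing them to NOT/OR is a genuine fix, not a style change.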
From b7486b1be2f48e2640c0ae7793460314ce8221ce Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Mon, 11 Dec 2023 01:22:23 +0000
Subject: [PATCH 8/9] enable cutlass for win+gpu

---
 cmake/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index d6620160cbbad..2416d48b32bb7 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -95,7 +95,7 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
 option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
 option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
 
-cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with Cutlass support" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
+cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cCutlass support" ON "onnxruntime_USE_CUDA" OFF)
 cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
 option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

From fa76634ddc14db47f74f4f5583a0ed51c1db518e Mon Sep 17 00:00:00 2001
From: "pengwa@microsoft.com"
Date: Mon, 11 Dec 2023 01:23:25 +0000
Subject: [PATCH 9/9] typo

---
 cmake/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 2416d48b32bb7..30d9a39ad57f7 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -95,7 +95,7 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
 option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
 option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
 
-cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cCutlass support" ON "onnxruntime_USE_CUDA" OFF)
+cmake_dependent_option(onnxruntime_USE_CUTLASS "Build with cutlass support" ON "onnxruntime_USE_CUDA" OFF)
 cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
 option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
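End state after PATCH 8 and PATCH 9: with cmake_dependent_option, onnxruntime_USE_CUTLASS defaults to ON whenever every entry in the semicolon-separated dependency list holds — now just onnxruntime_USE_CUDA, since `NOT WIN32` was dropped to enable cutlass on Windows GPU builds — and is forced to the trailing OFF value for non-CUDA builds.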