From de46ba5807b8c779e911de4987396c7e151f3273 Mon Sep 17 00:00:00 2001 From: "Peng Wang(AI FWK)" Date: Fri, 8 Dec 2023 10:29:00 +0000 Subject: [PATCH 1/2] fix --- onnxruntime/contrib_ops/cuda/math/gemm_float8.cu | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 56b541f5256bf..a22e1dbae3f4f 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -251,15 +251,21 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb))); +#if CUDA_VERSION >= 11060 + // CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET exists from https://docs.nvidia.com/cuda/archive/11.6.0/pdf/CUBLAS_Library.pdf if (sm_count_ != 0) { int math_sm_count = static_cast<int>(sm_count_); CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, sizeof(math_sm_count))); } +#endif if (has_scales) { // gemm float 8 +#if CUDA_VERSION >= 11080 + // CUBLASLT_MATMUL_DESC_FAST_ACCUM, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + // CUBLASLT_MATMUL_DESC_D_SCALE_POINTER exist from https://docs.nvidia.com/cuda/archive/11.8.0/pdf/CUBLAS_Library.pdf const int8_t ifast_accumulation_mode = 1; CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, @@ -274,6 +280,7 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, sizeof(p_scale_b))); +#endif // float 8 #if !defined(DISABLE_FLOAT8_TYPES) From a200264eee4838984d49aade5f62a8e42d45e3ee Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Fri, 8 Dec 2023 10:31:36 +0000 Subject: [PATCH 2/2] format --- onnxruntime/contrib_ops/cuda/math/gemm_float8.cu | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index a22e1dbae3f4f..064b6dd392437 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -280,7 +280,7 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, sizeof(p_scale_b))); -#endif +#endif // float 8 #if !defined(DISABLE_FLOAT8_TYPES)