diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 56b541f5256bf..064b6dd392437 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -251,15 +251,21 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb))); +#if CUDA_VERSION >= 11060 + // CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET exists from https://docs.nvidia.com/cuda/archive/11.6.0/pdf/CUBLAS_Library.pdf if (sm_count_ != 0) { int math_sm_count = static_cast(sm_count_); CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, sizeof(math_sm_count))); } +#endif if (has_scales) { // gemm float 8 +#if CUDA_VERSION >= 11080 + // CUBLASLT_MATMUL_DESC_FAST_ACCUM, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + // CUBLASLT_MATMUL_DESC_D_SCALE_POINTER exist from https://docs.nvidia.com/cuda/archive/11.8.0/pdf/CUBLAS_Library.pdf const int8_t ifast_accumulation_mode = 1; CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, @@ -274,6 +280,7 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, sizeof(p_scale_b))); +#endif // float 8 #if !defined(DISABLE_FLOAT8_TYPES)