Skip to content

Commit

Permalink
Merge branch 'microsoft:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel authored Dec 8, 2023
2 parents f13d1a7 + 44b5843 commit 6cfa0ab
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,21 @@ Status GemmFloat8::ComputeGemm(
CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb)));

#if CUDA_VERSION >= 11060
// CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET is available starting with CUDA 11.6; see https://docs.nvidia.com/cuda/archive/11.6.0/pdf/CUBLAS_Library.pdf
if (sm_count_ != 0) {
int math_sm_count = static_cast<int>(sm_count_);
CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count,
sizeof(math_sm_count)));
}
#endif

if (has_scales) {
// gemm float 8
#if CUDA_VERSION >= 11080
// CUBLASLT_MATMUL_DESC_FAST_ACCUM, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, and
// CUBLASLT_MATMUL_DESC_D_SCALE_POINTER are available starting with CUDA 11.8; see https://docs.nvidia.com/cuda/archive/11.8.0/pdf/CUBLAS_Library.pdf
const int8_t ifast_accumulation_mode = 1;
CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc,
Expand All @@ -274,6 +280,7 @@ Status GemmFloat8::ComputeGemm(
CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y,
sizeof(p_scale_b)));
#endif

// float 8
#if !defined(DISABLE_FLOAT8_TYPES)
Expand Down

0 comments on commit 6cfa0ab

Please sign in to comment.