From 59d1232ea76a7d760847244f9fffa3e75a590d32 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 25 Oct 2023 10:26:58 +0300
Subject: [PATCH 1/7] cuda : prints wip

---
 ggml-cuda.cu | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index ba0cd5a7d3f1e..7bbef0a1a7864 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6304,6 +6304,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
 
+        //printf("F16: row_diff: %ld, src1_ncols: %ld, ne10: %ld, ne00: %ld, ldc: %d\n", row_diff, src1_ncols, ne10, ne00, ldc);
         CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
         CUBLAS_CHECK(
             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
@@ -7250,6 +7251,12 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
         } else {
+            //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+            //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+            //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+            //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+            //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+            //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
         }
     }

From 52af78260884013abc3f2fc3669e68f45ec2bfb5 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 25 Oct 2023 13:14:24 +0300
Subject: [PATCH 2/7] cuda : new cublas gemm branch for multi-batch quantized
 src0

---
 ggml-cuda.cu | 115 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 113 insertions(+), 2 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 7bbef0a1a7864..ca49d73bfc4de 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6304,7 +6304,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
 
-        //printf("F16: row_diff: %ld, src1_ncols: %ld, ne10: %ld, ne00: %ld, ldc: %d\n", row_diff, src1_ncols, ne10, ne00, ldc);
         CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
         CUBLAS_CHECK(
             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
@@ -7049,9 +7048,10 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
-static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
@@ -7202,6 +7202,115 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
+static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    // require tensor cores
+    const int compute_capability = g_compute_capabilities[id];
+    GGML_ASSERT(compute_capability >= CC_VOLTA);
+
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+
+    //GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2]; GGML_UNUSED(ne02);
+    const int64_t ne03 = src0->ne[3]; GGML_UNUSED(ne03);
+
+    const int64_t nb01 = src0->nb[1]; GGML_UNUSED(nb01);
+    const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
+    const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2]; GGML_UNUSED(ne12);
+    const int64_t ne13 = src1->ne[3]; GGML_UNUSED(ne13);
+
+    const int64_t nb11 = src1->nb[1]; GGML_UNUSED(nb11);
+    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+
+    const int64_t ne1 = ggml_nelements(src1);
+    const int64_t ne  = ggml_nelements(dst);
+
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
+
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    if (ggml_is_contiguous(src0)) {
+        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
+        half * src0_as_f16 = nullptr;
+        size_t src0_as = 0;
+        if (src0->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            const size_t ne = ne01*ne00;
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
+            to_fp16_cuda(src0_ddq, src0_as_f16, ne, main_stream);
+        }
+
+        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_ddq : src0_as_f16;
+
+        half * src1_as_f16 = nullptr;
+        size_t src1_as = 0;
+        if (src1->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            const size_t ne = ne11*ne10;
+            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
+            to_fp16_cuda(src1_ddf, src1_as_f16, ne, main_stream);
+        }
+
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf : src1_as_f16;
+
+        size_t dst_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne01*ne11 * sizeof(half), &dst_as);
+
+        const half alpha_f16 = 1.0f;
+        const half beta_f16 = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
+                            src1_ptr, CUDA_R_16F, ne10,
+                &beta_f16,   dst_f16, CUDA_R_16F, ne01,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16, dst_ddf, ne01*ne11, main_stream);
+
+        ggml_cuda_pool_free(dst_f16, dst_as);
+
+        if (src0_as != 0) {
+            ggml_cuda_pool_free(src0_as_f16, src0_as);
+        }
+
+        if (src1_as != 0) {
+            ggml_cuda_pool_free(src1_as_f16, src1_as);
+        }
+    } else {
+        GGML_ASSERT(false && "not implemented");
+    }
+}
+
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
@@ -7231,6 +7340,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
+    } else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 1) {
+        ggml_cuda_mul_mat_mat_deq_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {

From 16b60dd75c8c89b726da5e9252454791fa1300b7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 25 Oct 2023 14:00:21 +0300
Subject: [PATCH 3/7] cuda : add F32 sgemm branch

---
 ggml-cuda.cu | 38 +++++++++++++++++++++++++++++++++++---
 1 file changed, 35 insertions(+), 3 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index ca49d73bfc4de..75e0dddf90b1c 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -7252,7 +7252,8 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
-    if (ggml_is_contiguous(src0)) {
+#if 0
+    {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         half * src0_as_f16 = nullptr;
         size_t src0_as = 0;
@@ -7306,9 +7307,40 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
         if (src1_as != 0) {
             ggml_cuda_pool_free(src1_as_f16, src1_as);
         }
-    } else {
-        GGML_ASSERT(false && "not implemented");
     }
+#else
+    {
+        // convert src0 to fp32, multiply as fp32
+        float * src0_as_f32 = nullptr;
+        size_t src0_as = 0;
+        if (src0->type != GGML_TYPE_F32) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
+            const size_t ne = ne01*ne00;
+            src0_as_f32 = (float *) ggml_cuda_pool_malloc(ne * sizeof(float), &src0_as);
+            to_fp32_cuda(src0_ddq, src0_as_f32, ne, main_stream);
+        }
+
+        const float * src0_ptr = src0->type == GGML_TYPE_F32 ? (const float *) src0_ddq : src0_as_f32;
+
+        const float * src1_ptr = (const float *) src1_ddf;
+
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+        CUBLAS_CHECK(
+            cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha, src0_ptr, ne00,
+                        src1_ptr, ne10,
+                &beta,  dst_ddf,  ne01));
+
+        if (src0_as != 0) {
+            ggml_cuda_pool_free(src0_as_f32, src0_as);
+        }
+    }
+#endif
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

From a3c28439d3c974db10092201e6228da709c801d0 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 25 Oct 2023 15:07:34 +0300
Subject: [PATCH 4/7] cuda : fine-tune >= VOLTA params + use MMQ only for small
 batches

---
 ggml-cuda.cu | 46 ++++++++++++++++++++++++----------------------
 1 file changed, 24 insertions(+), 22 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 75e0dddf90b1c..bafe080884073 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -3554,8 +3554,8 @@ static __device__ __forceinline__ void mul_mat_q(
 #define MMQ_X_Q4_0_RDNA1 64
 #define MMQ_Y_Q4_0_RDNA1 64
 #define NWARPS_Q4_0_RDNA1 8
-#define MMQ_X_Q4_0_AMPERE 64
-#define MMQ_Y_Q4_0_AMPERE 128
+#define MMQ_X_Q4_0_AMPERE 4
+#define MMQ_Y_Q4_0_AMPERE 32
 #define NWARPS_Q4_0_AMPERE 4
 #define MMQ_X_Q4_0_PASCAL 64
 #define MMQ_Y_Q4_0_PASCAL 64
@@ -3615,8 +3615,8 @@ template static __global__ void
 #define MMQ_X_Q4_1_RDNA1 64
 #define MMQ_Y_Q4_1_RDNA1 64
 #define NWARPS_Q4_1_RDNA1 8
-#define MMQ_X_Q4_1_AMPERE 64
-#define MMQ_Y_Q4_1_AMPERE 128
+#define MMQ_X_Q4_1_AMPERE 4
+#define MMQ_Y_Q4_1_AMPERE 32
 #define NWARPS_Q4_1_AMPERE 4
 #define MMQ_X_Q4_1_PASCAL 64
 #define MMQ_Y_Q4_1_PASCAL 64
@@ -3678,8 +3678,8 @@ template static __global__ void
 #define MMQ_X_Q5_0_RDNA1 64
 #define MMQ_Y_Q5_0_RDNA1 64
 #define NWARPS_Q5_0_RDNA1 8
-#define MMQ_X_Q5_0_AMPERE 128
-#define MMQ_Y_Q5_0_AMPERE 64
+#define MMQ_X_Q5_0_AMPERE 4
+#define MMQ_Y_Q5_0_AMPERE 32
 #define NWARPS_Q5_0_AMPERE 4
 #define MMQ_X_Q5_0_PASCAL 64
 #define MMQ_Y_Q5_0_PASCAL 64
@@ -3739,8 +3739,8 @@ template static __global__ void
 #define MMQ_X_Q5_1_RDNA1 64
 #define MMQ_Y_Q5_1_RDNA1 64
 #define NWARPS_Q5_1_RDNA1 8
-#define MMQ_X_Q5_1_AMPERE 128
-#define MMQ_Y_Q5_1_AMPERE 64
+#define MMQ_X_Q5_1_AMPERE 4
+#define MMQ_Y_Q5_1_AMPERE 32
 #define NWARPS_Q5_1_AMPERE 4
 #define MMQ_X_Q5_1_PASCAL 64
 #define MMQ_Y_Q5_1_PASCAL 64
@@ -3800,8 +3800,8 @@ mul_mat_q5_1(
 #define MMQ_X_Q8_0_RDNA1 64
 #define MMQ_Y_Q8_0_RDNA1 64
 #define NWARPS_Q8_0_RDNA1 8
-#define MMQ_X_Q8_0_AMPERE 128
-#define MMQ_Y_Q8_0_AMPERE 64
+#define MMQ_X_Q8_0_AMPERE 4
+#define MMQ_Y_Q8_0_AMPERE 32
 #define NWARPS_Q8_0_AMPERE 4
 #define MMQ_X_Q8_0_PASCAL 64
 #define MMQ_Y_Q8_0_PASCAL 64
@@ -3861,8 +3861,8 @@ template static __global__ void
 #define MMQ_X_Q2_K_RDNA1 128
 #define MMQ_Y_Q2_K_RDNA1 32
 #define NWARPS_Q2_K_RDNA1 8
-#define MMQ_X_Q2_K_AMPERE 64
-#define MMQ_Y_Q2_K_AMPERE 128
+#define MMQ_X_Q2_K_AMPERE 4
+#define MMQ_Y_Q2_K_AMPERE 32
 #define NWARPS_Q2_K_AMPERE 4
 #define MMQ_X_Q2_K_PASCAL 64
 #define MMQ_Y_Q2_K_PASCAL 64
@@ -3922,8 +3922,8 @@ mul_mat_q2_K(
 #define MMQ_X_Q3_K_RDNA1 32
 #define MMQ_Y_Q3_K_RDNA1 128
 #define NWARPS_Q3_K_RDNA1 8
-#define MMQ_X_Q3_K_AMPERE 128
-#define MMQ_Y_Q3_K_AMPERE 128
+#define MMQ_X_Q3_K_AMPERE 4
+#define MMQ_Y_Q3_K_AMPERE 32
 #define NWARPS_Q3_K_AMPERE 4
 #define MMQ_X_Q3_K_PASCAL 64
 #define MMQ_Y_Q3_K_PASCAL 64
@@ -3985,8 +3985,8 @@ template static __global__ void
 #define MMQ_X_Q4_K_RDNA1 32
 #define MMQ_Y_Q4_K_RDNA1 64
 #define NWARPS_Q4_K_RDNA1 8
-#define MMQ_X_Q4_K_AMPERE 64
-#define MMQ_Y_Q4_K_AMPERE 128
+#define MMQ_X_Q4_K_AMPERE 4
+#define MMQ_Y_Q4_K_AMPERE 32
 #define NWARPS_Q4_K_AMPERE 4
 #define MMQ_X_Q4_K_PASCAL 64
 #define MMQ_Y_Q4_K_PASCAL 64
@@ -4048,8 +4048,8 @@ template static __global__ void
 #define MMQ_X_Q5_K_RDNA1 32
 #define MMQ_Y_Q5_K_RDNA1 64
 #define NWARPS_Q5_K_RDNA1 8
-#define MMQ_X_Q5_K_AMPERE 64
-#define MMQ_Y_Q5_K_AMPERE 128
+#define MMQ_X_Q5_K_AMPERE 4
+#define MMQ_Y_Q5_K_AMPERE 32
 #define NWARPS_Q5_K_AMPERE 4
 #define MMQ_X_Q5_K_PASCAL 64
 #define MMQ_Y_Q5_K_PASCAL 64
@@ -4109,8 +4109,8 @@ mul_mat_q5_K(
 #define MMQ_X_Q6_K_RDNA1 32
 #define MMQ_Y_Q6_K_RDNA1 64
 #define NWARPS_Q6_K_RDNA1 8
-#define MMQ_X_Q6_K_AMPERE 64
-#define MMQ_Y_Q6_K_AMPERE 64
+#define MMQ_X_Q6_K_AMPERE 4
+#define MMQ_Y_Q6_K_AMPERE 32
 #define NWARPS_Q6_K_AMPERE 4
 #define MMQ_X_Q6_K_PASCAL 64
 #define MMQ_Y_Q6_K_PASCAL 64
@@ -7252,7 +7252,7 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
-#if 0
+#if 1
     {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         half * src0_as_f16 = nullptr;
@@ -7309,6 +7309,7 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
         }
     }
 #else
+    // NOTE: this seems faster for tiny models and small batch-size
     {
         // convert src0 to fp32, multiply as fp32
         float * src0_as_f32 = nullptr;
@@ -7372,7 +7373,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
-    } else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 1) {
+    } else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 32) {
+        // F16 and quantized src0 + high-batch src1
         ggml_cuda_mul_mat_mat_deq_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);

From 4c6744b526106bdd96c2fa54da2a8dff5da05240 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 25 Oct 2023 18:25:13 +0300
Subject: [PATCH 5/7] cuda : remove duplicated cuBLAS GEMM code

---
 ggml-cuda.cu | 169 +++++----------------------------------------------
 1 file changed, 15 insertions(+), 154 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index bafe080884073..27824f47f1ae4 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -7202,151 +7202,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
-static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
-
-    // require tensor cores
-    const int compute_capability = g_compute_capabilities[id];
-    GGML_ASSERT(compute_capability >= CC_VOLTA);
-
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-
-    //GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
-
-    GGML_ASSERT(ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2]; GGML_UNUSED(ne02);
-    const int64_t ne03 = src0->ne[3]; GGML_UNUSED(ne03);
-
-    const int64_t nb01 = src0->nb[1]; GGML_UNUSED(nb01);
-    const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
-    const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2]; GGML_UNUSED(ne12);
-    const int64_t ne13 = src1->ne[3]; GGML_UNUSED(ne13);
-
-    const int64_t nb11 = src1->nb[1]; GGML_UNUSED(nb11);
-    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
-    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
-
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne  = ggml_nelements(dst);
-
-    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
-
-    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    void * src0_ddq = src0_extra->data_device[g_main_device];
-
-    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
-
-    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
-    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
-
-#if 1
-    {
-        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        half * src0_as_f16 = nullptr;
-        size_t src0_as = 0;
-        if (src0->type != GGML_TYPE_F16) {
-            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
-            GGML_ASSERT(to_fp16_cuda != nullptr);
-            const size_t ne = ne01*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
-            to_fp16_cuda(src0_ddq, src0_as_f16, ne, main_stream);
-        }
-
-        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_ddq : src0_as_f16;
-
-        half * src1_as_f16 = nullptr;
-        size_t src1_as = 0;
-        if (src1->type != GGML_TYPE_F16) {
-            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-            GGML_ASSERT(to_fp16_cuda != nullptr);
-            const size_t ne = ne11*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
-            to_fp16_cuda(src1_ddf, src1_as_f16, ne, main_stream);
-        }
-
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf : src1_as_f16;
-
-        size_t dst_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne01*ne11 * sizeof(half), &dst_as);
-
-        const half alpha_f16 = 1.0f;
-        const half beta_f16 = 0.0f;
-
-        CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
-        CUBLAS_CHECK(
-            cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
-                &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
-                            src1_ptr, CUDA_R_16F, ne10,
-                &beta_f16,   dst_f16, CUDA_R_16F, ne01,
-                CUBLAS_COMPUTE_16F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16, dst_ddf, ne01*ne11, main_stream);
-
-        ggml_cuda_pool_free(dst_f16, dst_as);
-
-        if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_as_f16, src0_as);
-        }
-
-        if (src1_as != 0) {
-            ggml_cuda_pool_free(src1_as_f16, src1_as);
-        }
-    }
-#else
-    // NOTE: this seems faster for tiny models and small batch-size
-    {
-        // convert src0 to fp32, multiply as fp32
-        float * src0_as_f32 = nullptr;
-        size_t src0_as = 0;
-        if (src0->type != GGML_TYPE_F32) {
-            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
-            GGML_ASSERT(to_fp32_cuda != nullptr);
-            const size_t ne = ne01*ne00;
-            src0_as_f32 = (float *) ggml_cuda_pool_malloc(ne * sizeof(float), &src0_as);
-            to_fp32_cuda(src0_ddq, src0_as_f32, ne, main_stream);
-        }
-
-        const float * src0_ptr = src0->type == GGML_TYPE_F32 ? (const float *) src0_ddq : src0_as_f32;
-
-        const float * src1_ptr = (const float *) src1_ddf;
-
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-
-        CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
-        CUBLAS_CHECK(
-            cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
-                &alpha, src0_ptr, ne00,
-                        src1_ptr, ne10,
-                &beta,  dst_ddf,  ne01));
-
-        if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_as_f32, src0_as);
-        }
-    }
-#endif
-}
-
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
-        src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
+    const bool all_on_device =
+        (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
+        (src1->backend == GGML_BACKEND_GPU) &&
+        ( dst->backend == GGML_BACKEND_GPU);
 
     int64_t min_compute_capability = INT_MAX;
     for (int64_t id = 0; id < g_device_count; ++id) {
@@ -7373,9 +7233,6 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
-    } else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 32) {
-        // F16 and quantized src0 + high-batch src1
-        ggml_cuda_mul_mat_mat_deq_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
@@ -7393,15 +7250,19 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
         }
     } else {
-        if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
+        // ref: https://github.com/ggerganov/llama.cpp/pull/3776
+        bool use_mul_mat_q = g_mul_mat_q && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+
+        // TODO: better way to determine availability of tensor cores
+        // currently fails for GeForce GTX 1660 which is TURING arch but does not have tensor cores
+        if (min_compute_capability >= CC_VOLTA && src1->ne[1] > 32) {
+            // when tensor cores are available, use them for large batch size
+            use_mul_mat_q = false;
+        }
+
+        if (use_mul_mat_q) {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
         } else {
-            //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-            //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
-            //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
-            //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
-            //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
-            //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
         }
     }

From a4e15a36e4cd7120945eef560e591efaeb5fbd2b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 25 Oct 2023 18:48:36 +0300
Subject: [PATCH 6/7] cuda : add CUDA_USE_TENSOR_CORES and GGML_CUDA_FORCE_MMQ
 macros

---
 ggml-cuda.cu | 122 +++++++++++++++++++++++++++++++++++++++++++--------
 llama.cpp    |   2 -
 llama.h      |   2 +-
 3 files changed, 104 insertions(+), 22 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 27824f47f1ae4..558cd0bd8861a 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -87,6 +87,23 @@
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
 
+// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
+// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
+// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
+// - 7B quantum model: +100-200 MB
+// - 13B quantum model: +200-400 MB
+//#define GGML_CUDA_FORCE_MMQ
+
+// TODO: improve this to be correct for more hardware
+// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
+// probably other such cases, and not sure what happens on AMD hardware
+#if !defined(GGML_CUDA_FORCE_MMQ)
+#define CUDA_USE_TENSOR_CORES
+#endif
+
+// max batch size to use MMQ kernels when tensor cores are available
+#define MMQ_MAX_BATCH_SIZE 32
+
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
@@ -470,7 +487,6 @@ static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
-static bool g_mul_mat_q = true;
 
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
@@ -3554,9 +3570,15 @@ static __device__ __forceinline__ void mul_mat_q(
 #define MMQ_X_Q4_0_RDNA1 64
 #define MMQ_Y_Q4_0_RDNA1 64
 #define NWARPS_Q4_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q4_0_AMPERE 4
 #define MMQ_Y_Q4_0_AMPERE 32
 #define NWARPS_Q4_0_AMPERE 4
+#else
+#define MMQ_X_Q4_0_AMPERE 64
+#define MMQ_Y_Q4_0_AMPERE 128
+#define NWARPS_Q4_0_AMPERE 4
+#endif
 #define MMQ_X_Q4_0_PASCAL 64
 #define MMQ_Y_Q4_0_PASCAL 64
 #define NWARPS_Q4_0_PASCAL 8
@@ -3615,9 +3637,15 @@ template static __global__ void
 #define MMQ_X_Q4_1_RDNA1 64
 #define MMQ_Y_Q4_1_RDNA1 64
 #define NWARPS_Q4_1_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q4_1_AMPERE 4
 #define MMQ_Y_Q4_1_AMPERE 32
 #define NWARPS_Q4_1_AMPERE 4
+#else
+#define MMQ_X_Q4_1_AMPERE 64
+#define MMQ_Y_Q4_1_AMPERE 128
+#define NWARPS_Q4_1_AMPERE 4
+#endif
 #define MMQ_X_Q4_1_PASCAL 64
 #define MMQ_Y_Q4_1_PASCAL 64
 #define NWARPS_Q4_1_PASCAL 8
@@ -3678,9 +3706,15 @@ template static __global__ void
 #define MMQ_X_Q5_0_RDNA1 64
 #define MMQ_Y_Q5_0_RDNA1 64
 #define NWARPS_Q5_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q5_0_AMPERE 4
 #define MMQ_Y_Q5_0_AMPERE 32
 #define NWARPS_Q5_0_AMPERE 4
+#else
+#define MMQ_X_Q5_0_AMPERE 128
+#define MMQ_Y_Q5_0_AMPERE 64
+#define NWARPS_Q5_0_AMPERE 4
+#endif
 #define MMQ_X_Q5_0_PASCAL 64
 #define MMQ_Y_Q5_0_PASCAL 64
 #define NWARPS_Q5_0_PASCAL 8
@@ -3739,9 +3773,15 @@ template static __global__ void
 #define MMQ_X_Q5_1_RDNA1 64
 #define MMQ_Y_Q5_1_RDNA1 64
 #define NWARPS_Q5_1_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q5_1_AMPERE 4
 #define MMQ_Y_Q5_1_AMPERE 32
 #define NWARPS_Q5_1_AMPERE 4
+#else
+#define MMQ_X_Q5_1_AMPERE 128
+#define MMQ_Y_Q5_1_AMPERE 64
+#define NWARPS_Q5_1_AMPERE 4
+#endif
 #define MMQ_X_Q5_1_PASCAL 64
 #define MMQ_Y_Q5_1_PASCAL 64
 #define NWARPS_Q5_1_PASCAL 8
@@ -3800,9 +3840,15 @@ mul_mat_q5_1(
 #define MMQ_X_Q8_0_RDNA1 64
 #define MMQ_Y_Q8_0_RDNA1 64
 #define NWARPS_Q8_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q8_0_AMPERE 4
 #define MMQ_Y_Q8_0_AMPERE 32
 #define NWARPS_Q8_0_AMPERE 4
+#else
+#define MMQ_X_Q8_0_AMPERE 128
+#define MMQ_Y_Q8_0_AMPERE 64
+#define NWARPS_Q8_0_AMPERE 4
+#endif
 #define MMQ_X_Q8_0_PASCAL 64
 #define MMQ_Y_Q8_0_PASCAL 64
 #define NWARPS_Q8_0_PASCAL 8
@@ -3861,9 +3907,15 @@ template static __global__ void
 #define MMQ_X_Q2_K_RDNA1 128
 #define MMQ_Y_Q2_K_RDNA1 32
 #define NWARPS_Q2_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q2_K_AMPERE 4
 #define MMQ_Y_Q2_K_AMPERE 32
 #define NWARPS_Q2_K_AMPERE 4
+#else
+#define MMQ_X_Q2_K_AMPERE 64
+#define MMQ_Y_Q2_K_AMPERE 128
+#define NWARPS_Q2_K_AMPERE 4
+#endif
 #define MMQ_X_Q2_K_PASCAL 64
 #define MMQ_Y_Q2_K_PASCAL 64
 #define NWARPS_Q2_K_PASCAL 8
@@ -3922,9 +3974,15 @@ mul_mat_q2_K(
 #define MMQ_X_Q3_K_RDNA1 32
 #define MMQ_Y_Q3_K_RDNA1 128
 #define NWARPS_Q3_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q3_K_AMPERE 4
 #define MMQ_Y_Q3_K_AMPERE 32
 #define NWARPS_Q3_K_AMPERE 4
+#else
+#define MMQ_X_Q3_K_AMPERE 128
+#define MMQ_Y_Q3_K_AMPERE 128
+#define NWARPS_Q3_K_AMPERE 4
+#endif
 #define MMQ_X_Q3_K_PASCAL 64
 #define MMQ_Y_Q3_K_PASCAL 64
 #define NWARPS_Q3_K_PASCAL 8
@@ -3985,9 +4043,15 @@ template static __global__ void
 #define MMQ_X_Q4_K_RDNA1 32
 #define MMQ_Y_Q4_K_RDNA1 64
 #define NWARPS_Q4_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q4_K_AMPERE 4
 #define MMQ_Y_Q4_K_AMPERE 32
 #define NWARPS_Q4_K_AMPERE 4
+#else
+#define MMQ_X_Q4_K_AMPERE 64
+#define MMQ_Y_Q4_K_AMPERE 128
+#define NWARPS_Q4_K_AMPERE 4
+#endif
 #define MMQ_X_Q4_K_PASCAL 64
 #define MMQ_Y_Q4_K_PASCAL 64
 #define NWARPS_Q4_K_PASCAL 8
@@ -4048,9 +4112,15 @@ template static __global__ void
 #define MMQ_X_Q5_K_RDNA1 32
 #define MMQ_Y_Q5_K_RDNA1 64
 #define NWARPS_Q5_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q5_K_AMPERE 4
 #define MMQ_Y_Q5_K_AMPERE 32
 #define NWARPS_Q5_K_AMPERE 4
+#else
+#define MMQ_X_Q5_K_AMPERE 64
+#define MMQ_Y_Q5_K_AMPERE 128
+#define NWARPS_Q5_K_AMPERE 4
+#endif
 #define MMQ_X_Q5_K_PASCAL 64
 #define MMQ_Y_Q5_K_PASCAL 64
 #define NWARPS_Q5_K_PASCAL 8
@@ -4109,9 +4179,15 @@ mul_mat_q5_K(
 #define MMQ_X_Q6_K_RDNA1 32
 #define MMQ_Y_Q6_K_RDNA1 64
 #define NWARPS_Q6_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q6_K_AMPERE 4
 #define MMQ_Y_Q6_K_AMPERE 32
 #define NWARPS_Q6_K_AMPERE 4
+#else
+#define MMQ_X_Q6_K_AMPERE 64
+#define MMQ_Y_Q6_K_AMPERE 64
+#define NWARPS_Q6_K_AMPERE 4
+#endif
 #define MMQ_X_Q6_K_PASCAL 64
 #define MMQ_Y_Q6_K_PASCAL 64
 #define NWARPS_Q6_K_PASCAL 8
@@ -5663,6 +5739,16 @@ void ggml_init_cublas() {
     CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
     GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
     int64_t total_vram = 0;
+#if defined(GGML_CUDA_FORCE_MMQ)
+    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
+#else
+    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
+#endif
+#if defined(CUDA_USE_TENSOR_CORES)
+    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+#else
+    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
+#endif
     fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
     for (int id = 0; id < g_device_count; ++id) {
         cudaDeviceProp prop;
@@ -6347,7 +6433,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                     row_diff, src1_ncols, ne10,
                     &alpha, src0_ddf_i, ne00,
-                            src1_ddf_i, ne10, 
+                            src1_ddf_i, ne10,
                     &beta,  dst_dd_i,   ldc));
 
     if (src0_as != 0) {
@@ -7204,18 +7290,23 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
-        (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
+        (src0->backend == GGML_BACKEND_GPU) &&
         (src1->backend == GGML_BACKEND_GPU) &&
        ( dst->backend == GGML_BACKEND_GPU);
 
     int64_t min_compute_capability = INT_MAX;
     for (int64_t id = 0; id < g_device_count; ++id) {
-        if (min_compute_capability > g_compute_capabilities[id]
-            && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+        if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
            min_compute_capability = g_compute_capabilities[id];
        }
    }
 
+#ifdef CUDA_USE_TENSOR_CORES
+    const bool use_tensor_cores = true;
+#else
+    const bool use_tensor_cores = false;
+#endif
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -7224,20 +7315,19 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
-    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
-
 #ifdef GGML_CUDA_FORCE_DMMV
             const bool use_mul_mat_vec_q = false;
 #else
@@ -7250,13 +7340,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
         }
     } else {
-        // ref: https://github.com/ggerganov/llama.cpp/pull/3776
-        bool use_mul_mat_q = g_mul_mat_q && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+        bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
 
-        // TODO: better way to determine availability of tensor cores
-        // currently fails for GeForce GTX 1660 which is TURING arch but does not have tensor cores
-        if (min_compute_capability >= CC_VOLTA && src1->ne[1] > 32) {
-            // when tensor cores are available, use them for large batch size
+        // when tensor cores are available, use them for large batch size
+        // ref: https://github.com/ggerganov/llama.cpp/pull/3776
+        if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1->ne[1] > MMQ_MAX_BATCH_SIZE) {
             use_mul_mat_q = false;
         }
 
@@ -7614,10 +7702,6 @@ void ggml_cuda_set_main_device(const int main_device) {
     }
 }
 
-void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
-    g_mul_mat_q = mul_mat_q;
-}
-
 void ggml_cuda_set_scratch_size(const size_t scratch_size) {
     // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously
     // it still won't always work as expected, but it's better than nothing
diff --git a/llama.cpp b/llama.cpp
index 61f30c3982f18..cc8669b0e9e23 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -5959,8 +5959,6 @@ static int llama_decode_internal(
         }
     }
 
-    ggml_cuda_set_mul_mat_q(cparams.mul_mat_q);
-
     // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed
     if (!lctx.embedding.empty()) {
         embeddings->backend = GGML_BACKEND_CPU;
diff --git a/llama.h b/llama.h
index 2f2fee0e2ff9f..beac9a0cedd76 100644
--- a/llama.h
+++ b/llama.h
@@ -178,7 +178,7 @@ extern "C" {
         float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels
+        bool mul_mat_q;  // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
         bool f16_kv;     // use fp16 for KV cache, fp32 otherwise
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool embedding;  // embedding mode only

From 49af767fadfc44dd079d49038089b4ee99d77e0c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 27 Oct 2023 13:21:04 +0300
Subject: [PATCH 7/7] build : add compile option to force use of MMQ kernels

---
 CMakeLists.txt | 7 +++++++
 Makefile       | 3 +++
 ggml-cuda.cu   | 1 +
 3 files changed, 11 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 202f260491d39..d9fc86237b15c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -82,6 +82,7 @@ set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 option(LLAMA_CUBLAS                          "llama: use CUDA"                                 OFF)
 #option(LLAMA_CUDA_CUBLAS                    "llama: use cuBLAS for prompt processing"         OFF)
 option(LLAMA_CUDA_FORCE_DMMV                 "llama: use dmmv instead of mmvq CUDA kernels"    OFF)
+option(LLAMA_CUDA_FORCE_MMQ                  "llama: use mmq kernels instead of cuBLAS"        OFF)
 set(LLAMA_CUDA_DMMV_X      "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_MMV_Y        "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16                        "llama: use 16 bit floats for some calculations"  OFF)
@@ -305,6 +306,9 @@ if (LLAMA_CUBLAS)
     if (LLAMA_CUDA_FORCE_DMMV)
         add_compile_definitions(GGML_CUDA_FORCE_DMMV)
     endif()
+    if (LLAMA_CUDA_FORCE_MMQ)
+        add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+    endif()
     add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
     add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
     if (DEFINED LLAMA_CUDA_DMMV_Y)
@@ -405,6 +409,9 @@ if (LLAMA_HIPBLAS)
     if (LLAMA_CUDA_FORCE_DMMV)
         target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV)
     endif()
+    if (LLAMA_CUDA_FORCE_MMQ)
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_MMQ)
+    endif()
     target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
     target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
     target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
diff --git a/Makefile b/Makefile
index 80179631f95a5..68069f9ff331e 100644
--- a/Makefile
+++ b/Makefile
@@ -397,6 +397,9 @@ endif # CUDA_DOCKER_ARCH
 ifdef LLAMA_CUDA_FORCE_DMMV
 	NVCCFLAGS += -DGGML_CUDA_FORCE_DMMV
 endif # LLAMA_CUDA_FORCE_DMMV
+ifdef LLAMA_CUDA_FORCE_MMQ
+	NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ
+endif # LLAMA_CUDA_FORCE_MMQ
 ifdef LLAMA_CUDA_DMMV_X
 	NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
 else
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 558cd0bd8861a..1ba951f688d82 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -92,6 +92,7 @@
 // for large computational tasks. the drawback is that this requires some extra amount of VRAM:
 // - 7B quantum model: +100-200 MB
 // - 13B quantum model: +200-400 MB
+//
 //#define GGML_CUDA_FORCE_MMQ
 
 // TODO: improve this to be correct for more hardware
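
---

Addendum (editor's note, not part of the patch series): the core of patches 2-5 is to dequantize src0 to F16 on the device and then issue a single cublasGemmEx call with F16 operands and CUBLAS_COMPUTE_16F, which is what routes the work onto the tensor cores. The sketch below reproduces just that call pattern as a self-contained program. It is an illustration under stated assumptions, not code from ggml-cuda.cu: it uses plain cudaMalloc instead of ggml's pooled allocator, performs the F32 to F16 conversion on the host instead of with to_fp16_cuda, and the file name, sizes, and buffer names are invented for the example.

// gemm_f16_sketch.cu - minimal illustration of the F16 GEMM call used above
// build (assumed): nvcc gemm_f16_sketch.cu -lcublas   (CUDA 11+ for CUBLAS_COMPUTE_16F)
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <vector>

#define CHECK_CUDA(x)   do { if ((x) != cudaSuccess)           { fprintf(stderr, "cuda error %s:%d\n",   __FILE__, __LINE__); exit(1); } } while (0)
#define CHECK_CUBLAS(x) do { if ((x) != CUBLAS_STATUS_SUCCESS) { fprintf(stderr, "cublas error %s:%d\n", __FILE__, __LINE__); exit(1); } } while (0)

int main() {
    // m, n, k play the roles of ne01, ne11, ne10 in the patch
    const int m = 4, n = 2, k = 8;

    // the patch converts quantized src0 to F16 on the GPU with to_fp16_cuda;
    // here the conversion happens on the host just to keep the sketch short
    std::vector<__half> a(k * m), b(k * n);
    for (int i = 0; i < k * m; ++i) a[i] = __float2half(0.01f * i);
    for (int i = 0; i < k * n; ++i) b[i] = __float2half(0.02f * i);

    __half *d_a, *d_b, *d_c;
    CHECK_CUDA(cudaMalloc(&d_a, k * m * sizeof(__half)));   // A: k x m, column-major
    CHECK_CUDA(cudaMalloc(&d_b, k * n * sizeof(__half)));   // B: k x n
    CHECK_CUDA(cudaMalloc(&d_c, m * n * sizeof(__half)));   // C: m x n
    CHECK_CUDA(cudaMemcpy(d_a, a.data(), k * m * sizeof(__half), cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(d_b, b.data(), k * n * sizeof(__half), cudaMemcpyHostToDevice));

    cublasHandle_t handle;
    CHECK_CUBLAS(cublasCreate(&handle));

    // same call shape as the patch: C = op(A) * B with op(A) = A^T,
    // all operands CUDA_R_16F, compute type CUBLAS_COMPUTE_16F (tensor cores on Volta+)
    const __half alpha = __float2half(1.0f);
    const __half beta  = __float2half(0.0f);
    CHECK_CUBLAS(cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                              m, n, k,
                              &alpha, d_a, CUDA_R_16F, k,  // lda = k, as ne00 in the patch
                                      d_b, CUDA_R_16F, k,  // ldb = k, as ne10
                              &beta,  d_c, CUDA_R_16F, m,  // ldc = m, as ne01
                              CUBLAS_COMPUTE_16F,
                              CUBLAS_GEMM_DEFAULT_TENSOR_OP));

    // the patch converts dst back to F32 on the GPU; printing from the host suffices here
    std::vector<__half> c(m * n);
    CHECK_CUDA(cudaMemcpy(c.data(), d_c, m * n * sizeof(__half), cudaMemcpyDeviceToHost));
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
            printf("C[%d][%d] = %f\n", i, j, __half2float(c[j * m + i]));
        }
    }

    cublasDestroy(handle);
    cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);
    return 0;
}

The leading dimensions mirror the patch exactly: src0 is stored row-major as ne01 x ne00, so passing it column-major with CUBLAS_OP_T and lda = ne00 yields the desired src0 * src1 product without any explicit transpose.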
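The dispatch rule that patches 5 and 6 converge on is worth reading in isolation: MMQ stays the default for quantized src0 wherever DP4A is available, and the dequantize-plus-cuBLAS path takes over only when tensor cores are compiled in and the batch is large. A restatement as a standalone predicate follows; the helper function itself is hypothetical (ggml_cuda_mul_mat inlines this logic rather than exposing a helper), but the constants match the #defines in ggml-cuda.cu.

#include <cstdint>

// constants as defined in ggml-cuda.cu
constexpr int MIN_CC_DP4A        = 610; // dp4a instruction: compute capability 6.1+
constexpr int CC_VOLTA           = 700;
constexpr int MMQ_MAX_BATCH_SIZE = 32;  // introduced in patch 6

// hypothetical helper mirroring the selection logic in ggml_cuda_mul_mat
bool select_mul_mat_q(bool src0_is_quantized, int min_compute_capability,
                      int64_t src1_ne1, bool use_tensor_cores) {
    // MMQ requires a quantized src0 and the DP4A instruction
    bool use_mul_mat_q = src0_is_quantized && min_compute_capability >= MIN_CC_DP4A;

    // when tensor cores are available (CUDA_USE_TENSOR_CORES) and the batch is
    // large, dequantizing to F16 and calling cuBLAS is preferred over MMQ
    if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1_ne1 > MMQ_MAX_BATCH_SIZE) {
        use_mul_mat_q = false;
    }
    return use_mul_mat_q;
}

Note how a single compile-time switch controls everything: defining GGML_CUDA_FORCE_MMQ prevents CUDA_USE_TENSOR_CORES from being defined, which both restores the large MMQ tile sizes (the #else branches added in patch 6) and makes use_tensor_cores false here, so MMQ is used at every batch size.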