From 79b2da9529daaf5459aac02a5bc67befa9ba088e Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Wed, 20 Nov 2024 01:35:31 -0500 Subject: [PATCH] Fix a missing feature of sdpa in flash attention one. --- ...scaled_dot_product_attention_flash_attn.cu | 100 +++++++++++++++++- ...ccv_nnc_scaled_dot_product_attention_mps.m | 2 +- test/int/nnc/cublas.tests.c | 94 ++++++++++++++++ 3 files changed, 194 insertions(+), 2 deletions(-) diff --git a/lib/nnc/cmd/scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu b/lib/nnc/cmd/scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu index cce50cb95..e59daaf51 100644 --- a/lib/nnc/cmd/scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu +++ b/lib/nnc/cmd/scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu @@ -28,7 +28,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c if (bias) // bias always requires a weight matrix. { assert(weights); } - ccv_nnc_tensor_view_t* const saved_softmax_lse = (ccv_nnc_tensor_view_t*)outputs[1]; + ccv_nnc_tensor_view_t* const saved_softmax_lse = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0; ccv_nnc_tensor_view_t* const o = (weights) ? (ccv_nnc_tensor_view_t*)outputs[2] : (ccv_nnc_tensor_view_t*)outputs[0]; const int q_nd = ccv_nnc_tensor_nd(q->info.dim); assert(q_nd == 3 || q_nd == 4); @@ -201,6 +201,104 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c cudaStream_t stream = ccv_nnc_stream_context_get_stream(stream_context); run_mha_fwd(params, stream, false); CUDA_ENFORCE(cudaGetLastError()); + if (weights) + { + const ccv_nnc_tensor_view_t* a = o; + const ccv_nnc_tensor_view_t* w = weights; + ccv_nnc_tensor_view_t* b = (ccv_nnc_tensor_view_t*)outputs[0]; + assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 1-d array + assert(CCV_IS_TENSOR_CONTIGUOUS(b)); + const int b_nd = ccv_nnc_tensor_nd(b->info.dim); + assert(b_nd == 3); + int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc; + const int w_nd = ccv_nnc_tensor_nd(w->info.dim); + const int transpose_w[2] = { + w_nd - 2, w_nd - 1 + }; + ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? 
w->stride : 0, w->info.dim, transpose_w, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc); + int a_rows, a_cols; + if (o_nd == 3) { + a_rows = odim[1] * odim[2]; + a_cols = odim[3]; + } else if (q_nd == 4) { + a_rows = odim[0] * odim[1]; + a_cols = odim[2] * odim[3]; + } + int b_rows, b_cols, b_rows_inc; + b_rows = b->info.dim[0] * b->info.dim[1]; + b_cols = b->info.dim[2]; + b_rows_inc = b_cols; + assert(a_rows == b_rows); + assert(a_cols == w_rows); + assert(w_cols == b_cols); + + const cublasOperation_t transa = CUBLAS_OP_T; + const cublasOperation_t transb = CUBLAS_OP_N; + const int lda_inc = w_cols_inc; + const int ldb_inc = a_cols; + size_t w_data_size = 0; + int w_datatype = w->info.datatype; + if (CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX) + { + ccv_nnc_tensor_param_t w_params = w->info; + w_datatype = (w_params.datatype & 0xff) << 12; + ccv_nnc_tensor_param_t depalettize_w_params = w_params; + depalettize_w_params.datatype = w_datatype; + depalettize_w_params.reserved = 0; + w_data_size = ccv_nnc_tensor_data_size(depalettize_w_params); + } + const size_t cublas_size = ccv_nnc_cublas_workspace_size_in_bytes(inputs, input_size, outputs, output_size); + void* workspace = 0; + if (w_data_size > 0) + workspace = ccv_nnc_stream_context_get_workspace(stream_context, cublas_size + w_data_size, CCV_TENSOR_GPU_MEMORY); + unsigned char* w_data = w->data.u8; + if (CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX) + { + ccv_nnc_tensor_param_t w_params = w->info; + const size_t count = ccv_nnc_tensor_count(w_params); + const int qbits = (w_params.datatype & 0xf00) >> 8; + const int number_in_blocks = w_params.reserved; + w_data = (unsigned char*)workspace + cublas_size; + ccv_nnc_compat_depalettize(w->data.u8, w_datatype, ccv_nnc_tensor_data_size_without_padding(w_params), qbits, number_in_blocks, w_data, count, stream_context); + } + cublasHandle_t cublas = ccv_nnc_stream_context_get_cublas(stream_context); + static const half one_f16 = 1; + static const float one_f32 = 1; + static const double one_f64 = 1; + static const double zero_f64 = 0; + const void* zero = &zero_f64; + const void* one; + switch (ccv_nnc_cuda_compute_datatype(b->info.datatype)) + { + case CUBLAS_COMPUTE_16F: + one = &one_f16; + break; + case CUBLAS_COMPUTE_32F: + case CUBLAS_COMPUTE_32F_FAST_TF32: + one = &one_f32; + break; + case CUBLAS_COMPUTE_64F: + one = &one_f64; + break; + default: + assert(0); + } + ccv_nnc_stream_context_set_cublas_workspace(cublas, stream_context, cublas_size); + if (bias) + { + int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc; + const static int no_transpose[2] = {}; + ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? 
bias->stride : 0, bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc); + assert(bias_batch_size == 1); + assert(bias_cols == b_cols); + assert(CCV_IS_TENSOR_CONTIGUOUS(bias)); + const void* const device_ones = ccv_nnc_stream_context_get_ones(stream_context, b_rows, b->info.datatype); + CUBLAS_ENFORCE(cublasGemmEx(cublas, CUBLAS_OP_N, CUBLAS_OP_N, b_cols, b_rows, 1, one, bias->data.u8, ccv_nnc_cuda_datatype(bias->info.datatype), bias_rows_inc, device_ones, ccv_nnc_cuda_datatype(b->info.datatype), 1, zero, b->data.u8, ccv_nnc_cuda_datatype(b->info.datatype), b_rows_inc, ccv_nnc_cuda_compute_datatype(b->info.datatype), CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + CUBLAS_ENFORCE(cublasGemmEx(cublas, transa, transb, b_cols, b_rows, a_cols, one, w_data, ccv_nnc_cuda_datatype(w_datatype), lda_inc, a->data.u8, ccv_nnc_cuda_datatype(a->info.datatype), ldb_inc, one, b->data.u8, ccv_nnc_cuda_datatype(b->info.datatype), b_rows_inc, ccv_nnc_cuda_compute_datatype(b->info.datatype), CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + CUBLAS_ENFORCE(cublasGemmEx(cublas, transa, transb, b_cols, b_rows, a_cols, one, w_data, ccv_nnc_cuda_datatype(w_datatype), lda_inc, a->data.u8, ccv_nnc_cuda_datatype(a->info.datatype), ldb_inc, zero, b->data.u8, ccv_nnc_cuda_datatype(b->info.datatype), b_rows_inc, ccv_nnc_cuda_compute_datatype(b->info.datatype), CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + } return CCV_NNC_EXEC_SUCCESS; } diff --git a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m index c17d5c0d9..47c30888b 100644 --- a/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m +++ b/lib/nnc/cmd/scaled_dot_product_attention/mps/ccv_nnc_scaled_dot_product_attention_mps.m @@ -226,7 +226,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c const int b_nd = ccv_nnc_tensor_nd(weights->info.dim); assert(b_nd == 2); - assert(CCV_IS_TENSOR_CONTIGUOUS(bias)); + assert(CCV_IS_TENSOR_CONTIGUOUS(c)); const int c_nd = ccv_nnc_tensor_nd(c->info.dim); assert(c_nd == 3); diff --git a/test/int/nnc/cublas.tests.c b/test/int/nnc/cublas.tests.c index 8b8cedf2f..55e4b33de 100644 --- a/test/int/nnc/cublas.tests.c +++ b/test/int/nnc/cublas.tests.c @@ -2832,6 +2832,100 @@ TEST_CASE("scaled dot product attention with flash_attn") #undef num_trials } +TEST_CASE("scaled dot product attention + unify head with flash_attn") +{ + GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_GPU_REF)); + ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new(); + ccv_nnc_tensor_symbol_t q = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q"); + ccv_nnc_tensor_symbol_t k = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k"); + ccv_nnc_tensor_symbol_t v = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "v"); + ccv_nnc_tensor_symbol_t w = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 512, 512), "w"); + ccv_nnc_tensor_symbol_t bias = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 512), "bias"); + ccv_nnc_tensor_symbol_t c = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "c"); + ccv_nnc_tensor_symbol_t r = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, 
CPU_TENSOR_NHWC(32F, 32, 128, 512), "r"); + ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(q, k, v, NO_TENSOR_SYMBOL, w, bias), TENSOR_SYMBOL_LIST(r, NO_TENSOR_SYMBOL, c), "scaled_dot_product_attention"); + ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); + ccv_nnc_graph_t* sdp_graph = 0; + ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0; + ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0; + ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena); + ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, q); + ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, k); + ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, v); + ccv_nnc_tensor_t* const w_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, w); + ccv_nnc_tensor_t* const bias_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bias); + dsfmt_t dsfmt; + int i; + dsfmt_init_gen_rand(&dsfmt, 1); + for (i = 0; i < 32 * 8 * 128 * 64; i++) + q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + for (i = 0; i < 32 * 8 * 128 * 64; i++) + k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + for (i = 0; i < 32 * 8 * 128 * 64; i++) + v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + for (i = 0; i < 512 * 512; i++) + w_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / sqrtf(512); + for (i = 0; i < 512; i++) + bias_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt); + ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 32, 128, 8, 64), 0); + ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 32, 128, 8, 64), 0); + ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 32, 128, 8, 64), 0); + ccv_nnc_tensor_t* const w_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 512, 512), 0); + ccv_nnc_tensor_t* const bias_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 512), 0); + ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor, w_tensor, bias_tensor), TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16, w_tensor_f16, bias_tensor_f16), 0); + ccv_nnc_symbolic_graph_t* const g_symbolic_graph = ccv_nnc_symbolic_graph_new(); + ccv_nnc_tensor_symbol_t gq = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 32, 128, 8, 64), "q"); + ccv_nnc_tensor_symbol_t gk = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 32, 128, 8, 64), "k"); + ccv_nnc_tensor_symbol_t gv = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 32, 128, 8, 64), "v"); + ccv_nnc_tensor_symbol_t gw = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 512, 512), "w"); + ccv_nnc_tensor_symbol_t gbias = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 512), "bias"); + ccv_nnc_tensor_symbol_t gc = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 32, 128, 8, 64), "c"); + ccv_nnc_tensor_symbol_t gr = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 32, 128, 512), "r"); + ccv_nnc_graph_exec_symbol_new(g_symbolic_graph, 
CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(gq, gk, gv, NO_TENSOR_SYMBOL, gw, gbias), TENSOR_SYMBOL_LIST(gr, NO_TENSOR_SYMBOL, gc), "scaled_dot_product_attention"); + ccv_nnc_graph_exec_symbol_autogen(g_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS); + ccv_nnc_graph_t* g_graph = 0; + ccv_nnc_tensor_arena_t* g_tensor_arena = 0; + ccv_nnc_graph_exec_arena_t* g_graph_exec_arena = 0; + ccv_nnc_symbolic_graph_compile(g_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(g_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(g_symbolic_graph), &g_graph, &g_tensor_arena, &g_graph_exec_arena); + ccv_nnc_tensor_t* const gq_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gq); + ccv_nnc_tensor_t* const gk_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gk); + ccv_nnc_tensor_t* const gv_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gv); + ccv_nnc_tensor_t* const gw_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gw); + ccv_nnc_tensor_t* const gbias_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gbias); + ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16, w_tensor_f16, bias_tensor_f16), TENSOR_LIST(gq_tensor, gk_tensor, gv_tensor, gw_tensor, gbias_tensor), 0); + ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0); + ccv_nnc_graph_run(g_graph, 0, TRAVERSE_FULL, 0, 0); + ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, r); + ccv_nnc_tensor_t* const o_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, c); + ccv_nnc_tensor_t* const gc_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gc); + ccv_nnc_tensor_t* const gr_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gr); + ccv_nnc_tensor_t* const ho_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 32, 128, 8, 64), 0); + ccv_nnc_tensor_t* const hr_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 32, 128, 512), 0); + ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gc_tensor, gr_tensor), TENSOR_LIST(ho_f16, hr_f16), 0); + ccv_nnc_tensor_t* const ho = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0); + ccv_nnc_tensor_t* const hr = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 512), 0); + ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(ho_f16, hr_f16), TENSOR_LIST(ho, hr), 0); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, o_tensor->data.f32, ho->data.f32, 32 * 128 * 8 * 64, 3e-3, "graph computed result should match scaled dot product attention op result"); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, r_tensor->data.f32, hr->data.f32, 32 * 128 * 512, 3e-2, "graph computed result should match scaled dot product attention op result"); + ccv_nnc_symbolic_graph_free(sdp_symbolic_graph); + ccv_nnc_tensor_arena_free(sdp_tensor_arena); + ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena); + ccv_nnc_graph_free(sdp_graph); + ccv_nnc_symbolic_graph_free(g_symbolic_graph); + ccv_nnc_tensor_arena_free(g_tensor_arena); + ccv_nnc_graph_exec_arena_free(g_graph_exec_arena); + ccv_nnc_graph_free(g_graph); + ccv_nnc_tensor_free(ho); + ccv_nnc_tensor_free(hr); + ccv_nnc_tensor_free(ho_f16); + ccv_nnc_tensor_free(hr_f16); + ccv_nnc_tensor_free(q_tensor_f16); + ccv_nnc_tensor_free(k_tensor_f16); + ccv_nnc_tensor_free(v_tensor_f16); + ccv_nnc_tensor_free(w_tensor_f16); + ccv_nnc_tensor_free(bias_tensor_f16); +} + TEST_CASE("scaled dot product attention gradient with flash_attn") { 
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_GPU_REF) &&
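
Note on the new code path above (the diff is excerpted; the last hunk continues past this point): when a weights input (and optionally a 1-d bias) is passed to CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, the flash-attention output is now treated as a [batch*seqlen x heads*head_dim] matrix and projected through the weight matrix into outputs[0], which is what the added "unify head" test exercises. The patch also stops dereferencing outputs[1] (the saved softmax log-sum-exp) unless output_size > 1, and corrects the MPS backend to assert contiguity of the output tensor c rather than of bias.

The snippet below is a minimal single-threaded CPU sketch of that projection, for reference only: unify_heads_ref and its argument names are illustrative and not part of ccv, and it assumes the usual row-major [out_features x in_features] weight layout (so the product is a * w^T), which matches the CUBLAS_OP_T on w in the cuBLAS calls above.

/* Reference sketch of the fused output projection, not part of ccv.
 * a:    attention output, row-major [rows x k], rows = batch * seqlen, k = heads * head_dim
 * w:    projection weight, row-major [n x k]; used as a * w^T
 * bias: optional length-n vector, broadcast over all rows (may be NULL)
 * b:    result, row-major [rows x n]
 * In the test case above: rows = 32 * 128, k = 8 * 64 = 512, n = 512. */
static void unify_heads_ref(const float* a, const float* w, const float* bias, float* b, int rows, int k, int n)
{
	int i, j, t;
	for (i = 0; i < rows; i++)
		for (j = 0; j < n; j++)
		{
			float sum = bias ? bias[j] : 0;
			for (t = 0; t < k; t++)
				sum += a[i * k + t] * w[j * k + t];
			b[i * n + j] = sum;
		}
}

In the CUDA kernel itself the same computation is issued as cuBLAS calls: with a bias, a rank-1 cublasGemmEx (bias times a device vector of ones, beta = zero) first broadcasts the bias into every row of outputs[0], and a second cublasGemmEx accumulates a * w^T on top of it (beta = one); without a bias, a single cublasGemmEx with beta = zero writes the product directly.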