Fix a missing feature of SDPA in the flash attention implementation.
liuliu committed Nov 20, 2024
1 parent f376a50 · commit 79b2da9
Showing 3 changed files with 194 additions and 2 deletions.
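
Judging from the diff below, the missing feature is the fused "unify head" output projection in the flash attention path of the scaled dot product attention forward command: when the optional weight matrix (and bias) inputs are supplied, the attention output is now projected on the GPU right after the flash attention kernel runs. In formula form (notation is mine; shapes are taken from the new test case, where q, k, v are 32x128x8x64, the weights are 512x512 and the bias has 512 elements):

	r = softmax(scale * q * k^T) * v * W^T + bias

Here softmax(scale * q * k^T) * v is what run_mha_fwd already computes (the raw attention output c in the test), and the new code adds the W^T projection and the bias via cuBLAS, producing the 32x128x512 result r.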
@@ -28,7 +28,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
	if (bias) // bias always requires a weight matrix.
	{ assert(weights); }

-	ccv_nnc_tensor_view_t* const saved_softmax_lse = (ccv_nnc_tensor_view_t*)outputs[1];
+	ccv_nnc_tensor_view_t* const saved_softmax_lse = output_size > 1 ? (ccv_nnc_tensor_view_t*)outputs[1] : 0;
	ccv_nnc_tensor_view_t* const o = (weights) ? (ccv_nnc_tensor_view_t*)outputs[2] : (ccv_nnc_tensor_view_t*)outputs[0];
	const int q_nd = ccv_nnc_tensor_nd(q->info.dim);
	assert(q_nd == 3 || q_nd == 4);
@@ -201,6 +201,104 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
	cudaStream_t stream = ccv_nnc_stream_context_get_stream(stream_context);
	run_mha_fwd(params, stream, false);
	CUDA_ENFORCE(cudaGetLastError());
	if (weights)
	{
		const ccv_nnc_tensor_view_t* a = o;
		const ccv_nnc_tensor_view_t* w = weights;
		ccv_nnc_tensor_view_t* b = (ccv_nnc_tensor_view_t*)outputs[0];
		assert(!bias || (bias->info.dim[1] == 0 || bias->info.dim[2] == 0 || bias->info.dim[3] == 0)); // It is a 1-d array
		assert(CCV_IS_TENSOR_CONTIGUOUS(b));
		const int b_nd = ccv_nnc_tensor_nd(b->info.dim);
		assert(b_nd == 3);
		int w_batch_size, w_rows, w_cols, w_batch_inc, w_rows_inc, w_cols_inc;
		const int w_nd = ccv_nnc_tensor_nd(w->info.dim);
		const int transpose_w[2] = {
			w_nd - 2, w_nd - 1
		};
		ccv_nnc_tensor_get_matrix_params(w->info, CCV_IS_TENSOR_VIEW(w) ? w->stride : 0, w->info.dim, transpose_w, &w_batch_size, &w_rows, &w_cols, &w_batch_inc, &w_rows_inc, &w_cols_inc);
		int a_rows, a_cols;
		if (o_nd == 3) {
			a_rows = odim[1] * odim[2];
			a_cols = odim[3];
		} else if (q_nd == 4) {
			a_rows = odim[0] * odim[1];
			a_cols = odim[2] * odim[3];
		}
		int b_rows, b_cols, b_rows_inc;
		b_rows = b->info.dim[0] * b->info.dim[1];
		b_cols = b->info.dim[2];
		b_rows_inc = b_cols;
		assert(a_rows == b_rows);
		assert(a_cols == w_rows);
		assert(w_cols == b_cols);

		const cublasOperation_t transa = CUBLAS_OP_T;
		const cublasOperation_t transb = CUBLAS_OP_N;
		const int lda_inc = w_cols_inc;
		const int ldb_inc = a_cols;
		size_t w_data_size = 0;
		int w_datatype = w->info.datatype;
		if (CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX)
		{
			ccv_nnc_tensor_param_t w_params = w->info;
			w_datatype = (w_params.datatype & 0xff) << 12;
			ccv_nnc_tensor_param_t depalettize_w_params = w_params;
			depalettize_w_params.datatype = w_datatype;
			depalettize_w_params.reserved = 0;
			w_data_size = ccv_nnc_tensor_data_size(depalettize_w_params);
		}
		const size_t cublas_size = ccv_nnc_cublas_workspace_size_in_bytes(inputs, input_size, outputs, output_size);
		void* workspace = 0;
		if (w_data_size > 0)
			workspace = ccv_nnc_stream_context_get_workspace(stream_context, cublas_size + w_data_size, CCV_TENSOR_GPU_MEMORY);
		unsigned char* w_data = w->data.u8;
		if (CCV_GET_DATA_TYPE(w->info.datatype) == CCV_QX)
		{
			ccv_nnc_tensor_param_t w_params = w->info;
			const size_t count = ccv_nnc_tensor_count(w_params);
			const int qbits = (w_params.datatype & 0xf00) >> 8;
			const int number_in_blocks = w_params.reserved;
			w_data = (unsigned char*)workspace + cublas_size;
			ccv_nnc_compat_depalettize(w->data.u8, w_datatype, ccv_nnc_tensor_data_size_without_padding(w_params), qbits, number_in_blocks, w_data, count, stream_context);
		}
		cublasHandle_t cublas = ccv_nnc_stream_context_get_cublas(stream_context);
		static const half one_f16 = 1;
		static const float one_f32 = 1;
		static const double one_f64 = 1;
		static const double zero_f64 = 0;
		const void* zero = &zero_f64;
		const void* one;
		switch (ccv_nnc_cuda_compute_datatype(b->info.datatype))
		{
			case CUBLAS_COMPUTE_16F:
				one = &one_f16;
				break;
			case CUBLAS_COMPUTE_32F:
			case CUBLAS_COMPUTE_32F_FAST_TF32:
				one = &one_f32;
				break;
			case CUBLAS_COMPUTE_64F:
				one = &one_f64;
				break;
			default:
				assert(0);
		}
		ccv_nnc_stream_context_set_cublas_workspace(cublas, stream_context, cublas_size);
		if (bias)
		{
			int bias_batch_size, bias_rows, bias_cols, bias_batch_inc, bias_rows_inc, bias_cols_inc;
			const static int no_transpose[2] = {};
			ccv_nnc_tensor_get_matrix_params(bias->info, CCV_IS_TENSOR_VIEW(bias) ? bias->stride : 0, bias->info.dim, no_transpose, &bias_batch_size, &bias_rows, &bias_cols, &bias_batch_inc, &bias_rows_inc, &bias_cols_inc);
			assert(bias_batch_size == 1);
			assert(bias_cols == b_cols);
			assert(CCV_IS_TENSOR_CONTIGUOUS(bias));
			const void* const device_ones = ccv_nnc_stream_context_get_ones(stream_context, b_rows, b->info.datatype);
			CUBLAS_ENFORCE(cublasGemmEx(cublas, CUBLAS_OP_N, CUBLAS_OP_N, b_cols, b_rows, 1, one, bias->data.u8, ccv_nnc_cuda_datatype(bias->info.datatype), bias_rows_inc, device_ones, ccv_nnc_cuda_datatype(b->info.datatype), 1, zero, b->data.u8, ccv_nnc_cuda_datatype(b->info.datatype), b_rows_inc, ccv_nnc_cuda_compute_datatype(b->info.datatype), CUBLAS_GEMM_DEFAULT_TENSOR_OP));
			CUBLAS_ENFORCE(cublasGemmEx(cublas, transa, transb, b_cols, b_rows, a_cols, one, w_data, ccv_nnc_cuda_datatype(w_datatype), lda_inc, a->data.u8, ccv_nnc_cuda_datatype(a->info.datatype), ldb_inc, one, b->data.u8, ccv_nnc_cuda_datatype(b->info.datatype), b_rows_inc, ccv_nnc_cuda_compute_datatype(b->info.datatype), CUBLAS_GEMM_DEFAULT_TENSOR_OP));
		} else {
			CUBLAS_ENFORCE(cublasGemmEx(cublas, transa, transb, b_cols, b_rows, a_cols, one, w_data, ccv_nnc_cuda_datatype(w_datatype), lda_inc, a->data.u8, ccv_nnc_cuda_datatype(a->info.datatype), ldb_inc, zero, b->data.u8, ccv_nnc_cuda_datatype(b->info.datatype), b_rows_inc, ccv_nnc_cuda_compute_datatype(b->info.datatype), CUBLAS_GEMM_DEFAULT_TENSOR_OP));
		}
	}
	return CCV_NNC_EXEC_SUCCESS;
}
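
For context on the new if (weights) branch above: the two cublasGemmEx calls implement an ordinary linear projection. The first call (k = 1, against a device vector of ones) broadcasts the bias into every row of b, and the second accumulates a times the transposed weight matrix on top of it with beta = one (without a bias, a single call writes the product directly with beta = zero). A naive equivalent, as a sketch only (not part of the commit; the function name and the contiguous row-major float layout are assumptions for illustration):

static void unify_head_reference(const float* const a, const float* const w, const float* const bias, float* const b, const int rows, const int in_features, const int out_features)
{
	/* a: rows x in_features, the flash attention output flattened to (batch * seqlen) x (heads * head_dim)
	 * w: out_features x in_features, so the projection is b = a * w^T
	 * b: rows x out_features, the final result written to outputs[0] */
	int i, j, k;
	for (i = 0; i < rows; i++)
		for (j = 0; j < out_features; j++)
		{
			float sum = bias ? bias[j] : 0; /* the k = 1 GEMM against the ones vector seeds b with the bias */
			for (k = 0; k < in_features; k++)
				sum += a[i * in_features + k] * w[j * in_features + k]; /* b = a * w^T, accumulated on top (beta = one) */
			b[i * out_features + j] = sum;
		}
}

The new test case in test/int/nnc/cublas.tests.c below exercises exactly this path, comparing both the raw attention output (c) and the projected result (r) against the fp32 CPU reference.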

@@ -226,7 +226,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c

	const int b_nd = ccv_nnc_tensor_nd(weights->info.dim);
	assert(b_nd == 2);
-	assert(CCV_IS_TENSOR_CONTIGUOUS(bias));
+	assert(CCV_IS_TENSOR_CONTIGUOUS(c));
	const int c_nd = ccv_nnc_tensor_nd(c->info.dim);
	assert(c_nd == 3);

test/int/nnc/cublas.tests.c: 94 additions & 0 deletions
@@ -2832,6 +2832,100 @@ TEST_CASE("scaled dot product attention with flash_attn")
#undef num_trials
}

TEST_CASE("scaled dot product attention + unify head with flash_attn")
{
GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_GPU_REF));
ccv_nnc_symbolic_graph_t* const sdp_symbolic_graph = ccv_nnc_symbolic_graph_new();
ccv_nnc_tensor_symbol_t q = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "q");
ccv_nnc_tensor_symbol_t k = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "k");
ccv_nnc_tensor_symbol_t v = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "v");
ccv_nnc_tensor_symbol_t w = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 512, 512), "w");
ccv_nnc_tensor_symbol_t bias = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 512), "bias");
ccv_nnc_tensor_symbol_t c = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), "c");
ccv_nnc_tensor_symbol_t r = ccv_nnc_tensor_symbol_new(sdp_symbolic_graph, CPU_TENSOR_NHWC(32F, 32, 128, 512), "r");
ccv_nnc_graph_exec_symbol_new(sdp_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(q, k, v, NO_TENSOR_SYMBOL, w, bias), TENSOR_SYMBOL_LIST(r, NO_TENSOR_SYMBOL, c), "scaled_dot_product_attention");
ccv_nnc_graph_exec_symbol_autogen(sdp_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
ccv_nnc_graph_t* sdp_graph = 0;
ccv_nnc_tensor_arena_t* sdp_tensor_arena = 0;
ccv_nnc_graph_exec_arena_t* sdp_graph_exec_arena = 0;
ccv_nnc_symbolic_graph_compile(sdp_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(sdp_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(sdp_symbolic_graph), &sdp_graph, &sdp_tensor_arena, &sdp_graph_exec_arena);
ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, q);
ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, k);
ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, v);
ccv_nnc_tensor_t* const w_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, w);
ccv_nnc_tensor_t* const bias_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, bias);
dsfmt_t dsfmt;
int i;
dsfmt_init_gen_rand(&dsfmt, 1);
for (i = 0; i < 32 * 8 * 128 * 64; i++)
q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
for (i = 0; i < 32 * 8 * 128 * 64; i++)
k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
for (i = 0; i < 32 * 8 * 128 * 64; i++)
v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
for (i = 0; i < 512 * 512; i++)
w_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt) / sqrtf(512);
for (i = 0; i < 512; i++)
bias_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 32, 128, 8, 64), 0);
ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 32, 128, 8, 64), 0);
ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 32, 128, 8, 64), 0);
ccv_nnc_tensor_t* const w_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 512, 512), 0);
ccv_nnc_tensor_t* const bias_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 512), 0);
ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor, w_tensor, bias_tensor), TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16, w_tensor_f16, bias_tensor_f16), 0);
ccv_nnc_symbolic_graph_t* const g_symbolic_graph = ccv_nnc_symbolic_graph_new();
ccv_nnc_tensor_symbol_t gq = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 32, 128, 8, 64), "q");
ccv_nnc_tensor_symbol_t gk = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 32, 128, 8, 64), "k");
ccv_nnc_tensor_symbol_t gv = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 32, 128, 8, 64), "v");
ccv_nnc_tensor_symbol_t gw = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 512, 512), "w");
ccv_nnc_tensor_symbol_t gbias = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 512), "bias");
ccv_nnc_tensor_symbol_t gc = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 32, 128, 8, 64), "c");
ccv_nnc_tensor_symbol_t gr = ccv_nnc_tensor_symbol_new(g_symbolic_graph, GPU_TENSOR_NHWC(000, 16F, 32, 128, 512), "r");
ccv_nnc_graph_exec_symbol_new(g_symbolic_graph, CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(1.0 / 8, 0), TENSOR_SYMBOL_LIST(gq, gk, gv, NO_TENSOR_SYMBOL, gw, gbias), TENSOR_SYMBOL_LIST(gr, NO_TENSOR_SYMBOL, gc), "scaled_dot_product_attention");
ccv_nnc_graph_exec_symbol_autogen(g_symbolic_graph, 0, 0, CCV_NNC_AUTOGEN_ALL_EXECS | CCV_NNC_AUTOGEN_SOURCES_AND_DESTINATIONS);
ccv_nnc_graph_t* g_graph = 0;
ccv_nnc_tensor_arena_t* g_tensor_arena = 0;
ccv_nnc_graph_exec_arena_t* g_graph_exec_arena = 0;
ccv_nnc_symbolic_graph_compile(g_symbolic_graph, ccv_nnc_default_compile_params, 0, 0, 0, 0, SYMBOLIC_GRAPH_SOURCES(g_symbolic_graph), SYMBOLIC_GRAPH_DESTINATIONS(g_symbolic_graph), &g_graph, &g_tensor_arena, &g_graph_exec_arena);
ccv_nnc_tensor_t* const gq_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gq);
ccv_nnc_tensor_t* const gk_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gk);
ccv_nnc_tensor_t* const gv_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gv);
ccv_nnc_tensor_t* const gw_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gw);
ccv_nnc_tensor_t* const gbias_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gbias);
ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16, w_tensor_f16, bias_tensor_f16), TENSOR_LIST(gq_tensor, gk_tensor, gv_tensor, gw_tensor, gbias_tensor), 0);
ccv_nnc_graph_run(sdp_graph, 0, TRAVERSE_FULL, 0, 0);
ccv_nnc_graph_run(g_graph, 0, TRAVERSE_FULL, 0, 0);
ccv_nnc_tensor_t* const r_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, r);
ccv_nnc_tensor_t* const o_tensor = ccv_nnc_tensor_from_symbol(sdp_tensor_arena, c);
ccv_nnc_tensor_t* const gc_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gc);
ccv_nnc_tensor_t* const gr_tensor = ccv_nnc_tensor_from_symbol(g_tensor_arena, gr);
ccv_nnc_tensor_t* const ho_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 32, 128, 8, 64), 0);
ccv_nnc_tensor_t* const hr_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, 32, 128, 512), 0);
ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gc_tensor, gr_tensor), TENSOR_LIST(ho_f16, hr_f16), 0);
ccv_nnc_tensor_t* const ho = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 8, 64), 0);
ccv_nnc_tensor_t* const hr = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, 32, 128, 512), 0);
ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(ho_f16, hr_f16), TENSOR_LIST(ho, hr), 0);
REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, o_tensor->data.f32, ho->data.f32, 32 * 128 * 8 * 64, 3e-3, "graph computed result should match scaled dot product attention op result");
REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, r_tensor->data.f32, hr->data.f32, 32 * 128 * 512, 3e-2, "graph computed result should match scaled dot product attention op result");
ccv_nnc_symbolic_graph_free(sdp_symbolic_graph);
ccv_nnc_tensor_arena_free(sdp_tensor_arena);
ccv_nnc_graph_exec_arena_free(sdp_graph_exec_arena);
ccv_nnc_graph_free(sdp_graph);
ccv_nnc_symbolic_graph_free(g_symbolic_graph);
ccv_nnc_tensor_arena_free(g_tensor_arena);
ccv_nnc_graph_exec_arena_free(g_graph_exec_arena);
ccv_nnc_graph_free(g_graph);
ccv_nnc_tensor_free(ho);
ccv_nnc_tensor_free(hr);
ccv_nnc_tensor_free(ho_f16);
ccv_nnc_tensor_free(hr_f16);
ccv_nnc_tensor_free(q_tensor_f16);
ccv_nnc_tensor_free(k_tensor_f16);
ccv_nnc_tensor_free(v_tensor_f16);
ccv_nnc_tensor_free(w_tensor_f16);
ccv_nnc_tensor_free(bias_tensor_f16);
}

TEST_CASE("scaled dot product attention gradient with flash_attn")
{
	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_GPU_REF) &&
