Add support for MFA backprop.
liuliu committed Oct 30, 2024
1 parent d69ee17 commit 18f1860
Showing 6 changed files with 479 additions and 183 deletions.
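In brief: the attention parameter struct gains a `.type` field (0 selects the forward kernel, 1 the backward kernel), the backward op now accepts the forward output `o` and the softmax log-sum-exp statistics `lse` as optional inputs, and when the tensors qualify the gradients are computed by the Metal FlashAttention (MFA) backward kernel, with the previous MPSGraph implementation retained as the fallback.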
@@ -155,6 +155,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
const int is_downcast = ((cmd.info.scaled_dot_product_attention.flags & CCV_NNC_GEMM_16F) && q->info.datatype == CCV_16F);
int attention_is_batched = (batch_size > 1);
ccv_nnc_mfa_attention_params_t params = {
.type = 0, // 0 selects the forward attention kernel.
.data_type = mtl_data_type,
.R = (uint32_t)R,
.C = (uint32_t)C,
@@ -359,11 +360,14 @@ static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, c
static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, const ccv_nnc_hint_t hint, const int flags, ccv_nnc_tensor_t* const* const inputs, const int input_size, ccv_nnc_tensor_t* const* const outputs, const int output_size, ccv_nnc_stream_context_t* const stream_context)
{
assert(input_size >= 6);
assert(output_size >= 3);
assert(!cmd.info.scaled_dot_product_attention.is_causal);
ccv_nnc_tensor_view_t* const g = (ccv_nnc_tensor_view_t*)inputs[0];
ccv_nnc_tensor_view_t* const q = (ccv_nnc_tensor_view_t*)inputs[3];
ccv_nnc_tensor_view_t* const k = (ccv_nnc_tensor_view_t*)inputs[4];
ccv_nnc_tensor_view_t* const v = (ccv_nnc_tensor_view_t*)inputs[5];
ccv_nnc_tensor_view_t* const o = input_size > 9 ? (ccv_nnc_tensor_view_t*)inputs[9] : 0;
ccv_nnc_tensor_view_t* const lse = input_size > 10 ? (ccv_nnc_tensor_view_t*)inputs[10] : 0;
ccv_nnc_tensor_view_t* const dq = (ccv_nnc_tensor_view_t*)outputs[0];
ccv_nnc_tensor_view_t* const dk = (ccv_nnc_tensor_view_t*)outputs[1];
ccv_nnc_tensor_view_t* const dv = (ccv_nnc_tensor_view_t*)outputs[2];
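The two new optional inputs are what a flash-attention backward pass needs from the forward pass: `o`, the forward output, and `lse`, which by the usual FlashAttention convention holds the per-row log-sum-exp of the attention logits, letting the kernel recompute the softmax probabilities without materializing the full R-by-C attention matrix. (The convention is inferred from the tensor's name; the diff itself only shows the tensors being threaded through.)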
@@ -442,77 +446,207 @@ static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, c
dvstride[0] = dvstride[1], dvstride[1] = dvstride[2], dvstride[2] = dvstride[3];
}
@autoreleasepool {
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
int indices[4];
const int* gdim_r = gdim;
const int* gstride_r = gstride;
const int* qdim_r = qdim;
const int* qstride_r = qstride;
const int* kdim_r = kdim;
const int* kstride_r = kstride;
const int* vdim_r = vdim;
const int* vstride_r = vstride;
const float scale = cmd.info.scaled_dot_product_attention.scale;

MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_g;
MPSGraphTensor* mps_g = ccv_nnc_mps_graph_tensor_input(graph, g, gdim_r, gstride_r, &mps_input_g);
[inputTensors addObject:mps_input_g];
MPSGraphShapedType* mps_g_shape = ccv_nnc_mps_graph_tensor_input_shape(g, gdim_r, gstride_r);
[inputShapedTypes addObject:mps_g_shape];

MPSGraphTensor* mps_input_q;
MPSGraphTensor* mps_q = ccv_nnc_mps_graph_tensor_input(graph, q, qdim_r, qstride_r, &mps_input_q);
[inputTensors addObject:mps_input_q];
MPSGraphShapedType* mps_q_shape = ccv_nnc_mps_graph_tensor_input_shape(q, qdim_r, qstride_r);
[inputShapedTypes addObject:mps_q_shape];

MPSGraphTensor* mps_input_k;
MPSGraphTensor* mps_k = ccv_nnc_mps_graph_tensor_input(graph, k, kdim_r, kstride_r, &mps_input_k);
[inputTensors addObject:mps_input_k];
MPSGraphShapedType* mps_k_shape = ccv_nnc_mps_graph_tensor_input_shape(k, kdim_r, kstride_r);
[inputShapedTypes addObject:mps_k_shape];

MPSGraphTensor* mps_input_v;
MPSGraphTensor* mps_v = ccv_nnc_mps_graph_tensor_input(graph, v, vdim_r, vstride_r, &mps_input_v);
[inputTensors addObject:mps_input_v];
MPSGraphShapedType* mps_v_shape = ccv_nnc_mps_graph_tensor_input_shape(v, vdim_r, vstride_r);
[inputShapedTypes addObject:mps_v_shape];

MPSGraphTensor* mps_scale = [graph constantWithScalar:scale dataType:ccv_nnc_mps_datatype(q->info.datatype)];
mps_q = [graph multiplicationWithPrimaryTensor:mps_scale secondaryTensor:[graph transposeTensor:mps_q dimension:1 withDimension:2 name:nil] name:nil];
mps_k = [graph transposeTensor:mps_k dimension:1 withDimension:2 name:nil];
MPSGraphTensor* mps_kt = [graph transposeTensor:mps_k dimension:2 withDimension:3 name:nil];
mps_v = [graph transposeTensor:mps_v dimension:1 withDimension:2 name:nil];
MPSGraphTensor* mps_qk = [graph matrixMultiplicationWithPrimaryTensor:mps_q secondaryTensor:mps_kt name:nil];
MPSGraphTensor* mps_softmax = [graph softMaxWithTensor:mps_qk axis:3 name:nil];
mps_g = [graph transposeTensor:mps_g dimension:1 withDimension:2 name:nil];
MPSGraphTensor* mps_softmaxt = [graph transposeTensor:mps_softmax dimension:2 withDimension:3 name:nil];
MPSGraphTensor* mps_dv = [graph matrixMultiplicationWithPrimaryTensor:mps_softmaxt secondaryTensor:mps_g name:nil];
mps_v = [graph transposeTensor:mps_v dimension:2 withDimension:3 name:nil];
MPSGraphTensor* mps_dsoftmax = [graph matrixMultiplicationWithPrimaryTensor:mps_g secondaryTensor:mps_v name:nil];
MPSGraphTensor* mulTensor = [graph multiplicationWithPrimaryTensor:mps_softmax secondaryTensor:mps_dsoftmax name:nil];
MPSGraphTensor* mulSumTensor = [graph reductionSumWithTensor:mulTensor axis:-1 name:nil];
MPSGraphTensor* gradSubTensor = [graph subtractionWithPrimaryTensor:mps_dsoftmax secondaryTensor:mulSumTensor name:nil];
MPSGraphTensor* mps_dqk = [graph multiplicationWithPrimaryTensor:mps_softmax secondaryTensor:gradSubTensor name:nil];
MPSGraphTensor* mps_dq = [graph multiplicationWithPrimaryTensor:mps_scale secondaryTensor:[graph matrixMultiplicationWithPrimaryTensor:mps_dqk secondaryTensor:mps_k name:nil] name:nil];
mps_dqk = [graph transposeTensor:mps_dqk dimension:2 withDimension:3 name:nil];
MPSGraphTensor* mps_dk = [graph matrixMultiplicationWithPrimaryTensor:mps_dqk secondaryTensor:mps_q name:nil];
mps_dq = [graph transposeTensor:mps_dq dimension:1 withDimension:2 name:nil];
[resultTensors addObject:mps_dq];
mps_dk = [graph transposeTensor:mps_dk dimension:1 withDimension:2 name:nil];
[resultTensors addObject:mps_dk];
mps_dv = [graph transposeTensor:mps_dv dimension:1 withDimension:2 name:nil];
[resultTensors addObject:mps_dv];
});
MPSGraphTensorData* data_g = ccv_nnc_mps_graph_tensor_data(g, gdim, gstride);
MPSGraphTensorData* data_q = ccv_nnc_mps_graph_tensor_data(q, qdim, qstride);
MPSGraphTensorData* data_k = ccv_nnc_mps_graph_tensor_data(k, kdim, kstride);
MPSGraphTensorData* data_v = ccv_nnc_mps_graph_tensor_data(v, vdim, vstride);
MPSGraphTensorData* data[] = {data_g, data_q, data_k, data_v};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]], data[indices[2]], data[indices[3]]], (ccv_nnc_tensor_view_t*[]){ dq, dk, dv }, (int*[]){ dqdim, dkdim, dvdim }, (int*[]){ dqstride, dkstride, dvstride }, 3, 0);
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
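// (Everything above in this hunk appears to be the removed, unconditionally-MPSGraph
// body; the rewritten body below tries the Metal FlashAttention kernel first and
// keeps the MPSGraph graph only as a fallback.)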
ccv_nnc_mfa_context_t* context = ccv_nnc_default_mfa_context();
const int is_contiguous =
(!CCV_IS_TENSOR_VIEW(q) || ccv_nnc_tensor_view_is_contiguous(qdim, qstride)) &&
(!CCV_IS_TENSOR_VIEW(k) || ccv_nnc_tensor_view_is_contiguous(kdim, kstride)) &&
(!CCV_IS_TENSOR_VIEW(v) || ccv_nnc_tensor_view_is_contiguous(vdim, vstride)) &&
(!CCV_IS_TENSOR_VIEW(g) || ccv_nnc_tensor_view_is_contiguous(gdim, gstride)) &&
(!CCV_IS_TENSOR_VIEW(dq) || ccv_nnc_tensor_view_is_contiguous(dqdim, dqstride)) &&
(!CCV_IS_TENSOR_VIEW(dk) || ccv_nnc_tensor_view_is_contiguous(dkdim, dkstride)) &&
(!CCV_IS_TENSOR_VIEW(dv) || ccv_nnc_tensor_view_is_contiguous(dvdim, dvstride)) &&
(o ? (!CCV_IS_TENSOR_VIEW(o) || ccv_nnc_tensor_view_is_contiguous(o->info.dim, o->stride)) : 1) &&
(lse ? (!CCV_IS_TENSOR_VIEW(lse) || ccv_nnc_tensor_view_is_contiguous(lse->info.dim, lse->stride)) : 1);
const int is_same_dtype =
(q->info.datatype == k->info.datatype) &&
(k->info.datatype == v->info.datatype) &&
(v->info.datatype == g->info.datatype) &&
(g->info.datatype == dq->info.datatype) &&
(dq->info.datatype == dk->info.datatype) &&
(dk->info.datatype == dv->info.datatype) &&
(o ? (g->info.datatype == o->info.datatype) : 1);
const int is_supported_dtype = q->info.datatype == CCV_16F || q->info.datatype == CCV_32F;
uint32_t mtl_data_type = UINT32_MAX;
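// Raw MTLDataType values: MTLDataTypeHalf == 16, MTLDataTypeFloat == 3.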
switch (q->info.datatype) {
case CCV_16F: {
mtl_data_type = 16;
break;
}
case CCV_32F: {
mtl_data_type = 3;
break;
}
default: {
assert(false);
break;
}
}
int batch_size;
int R;
int C;
int Hq;
int Hk;
int D;
if (q_nd == 3) {
batch_size = qdim[1];
assert(batch_size == kdim[1]);
R = qdim[2];
C = kdim[2];
Hq = 1;
Hk = 1;
D = qdim[3];
assert(D == kdim[3]);
} else if (q_nd == 4) {
batch_size = qdim[0];
assert(batch_size == kdim[0]);
R = qdim[1];
C = kdim[1];
Hq = qdim[2];
Hk = kdim[2];
assert(Hq >= Hk);
assert((Hq % Hk) == 0);
D = qdim[3];
assert(D == kdim[3]);
}
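// q_nd == 3 is treated as [batch, R, D] with a single head; q_nd == 4 as
// [batch, R, heads, D]. The asserts allow grouped-query attention (Hq a
// multiple of Hk), but the MFA eligibility check below also requires Hq == Hk.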
const int is_mfa_supported =
ccv_nnc_mfa_context_supported(context) && is_contiguous && is_same_dtype && is_supported_dtype && !(ccv_nnc_flags() & CCV_NNC_DISABLE_METAL_FLASH_ATTENTION) && (Hq == Hk);
if (is_mfa_supported)
{
const int o_nd = ccv_nnc_tensor_nd(o->info.dim);
assert(o_nd == 3 || o_nd == 4);
int odim[CCV_NNC_MAX_DIM_ALLOC];
ccv_nnc_tensor_view_get_dim(o, odim);
if (o_nd == 3)
odim[0] = odim[1], odim[1] = odim[2], odim[2] = 1;
const int is_downcast = ((cmd.info.scaled_dot_product_attention.flags & CCV_NNC_GEMM_16F) && q->info.datatype == CCV_16F);
int attention_is_batched = (batch_size > 1);
ccv_nnc_mfa_attention_params_t params = {
.type = 1, // 1 selects the backward kernel (gradients for query / key / value).
.data_type = mtl_data_type,
.R = (uint32_t)R,
.C = (uint32_t)C,
.Hq = (uint32_t)Hq,
.Hk = (uint32_t)Hk,
.D = (uint32_t)D,
.Q_trans = false,
.K_trans = false,
.V_trans = false,
.O_trans = false,
.alpha = cmd.info.scaled_dot_product_attention.scale,
.batched = (attention_is_batched ? 1 : 0),
.masked = 0,
.upcast = !is_downcast,

.batch_dims_q = { 0 },
.batch_dims_mask = { 0 },
};
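// batch_dims_q / batch_dims_mask appear to be zero-terminated dimension lists;
// they are filled in only for the batched case below.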
if (attention_is_batched) {
params.batch_dims_q[0] = batch_size;
params.batch_dims_q[1] = 0;
params.batch_dims_mask[0] = batch_size;
params.batch_dims_mask[1] = 0;
}
ccv_nnc_mfa_prepare_attention(context, params);

mtl_command_batch_t* command_batch = ccv_nnc_stream_context_start_command_batch(stream_context);
mtl_buffer_t* tensors[10] = {
mpgetbuffer((ccv_nnc_tensor_t*)q),
mpgetbuffer((ccv_nnc_tensor_t*)k),
mpgetbuffer((ccv_nnc_tensor_t*)v),
mpgetbuffer((ccv_nnc_tensor_t*)o),
lse ? mpgetbuffer((ccv_nnc_tensor_t*)lse) : 0,
mpgetbuffer((ccv_nnc_tensor_t*)g),
mpgetbuffer((ccv_nnc_tensor_t*)dq),
mpgetbuffer((ccv_nnc_tensor_t*)dk),
mpgetbuffer((ccv_nnc_tensor_t*)dv),
NULL,
};
size_t tensor_offsets[9] = {
q->dataof,
k->dataof,
v->dataof,
o->dataof,
lse ? lse->dataof : 0,
g->dataof,
dq->dataof,
dk->dataof,
dv->dataof
};
ccv_nnc_mfa_encode_attention(context, params, command_batch, tensors, tensor_offsets);
ccv_nnc_stream_context_finish_command_batch(stream_context, command_batch);
} else {
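// Fall back to the MPSGraph implementation when MFA is unavailable or the
// tensor layout / dtype is unsupported.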
MPSCommandBuffer* command_buffer = ccv_nnc_stream_context_start_mps_command_buffer(stream_context);
ccv_nnc_mps_graph_key_t key = ccv_nnc_mps_graph_key_new(cmd, 0, hint, flags, inputs, input_size, outputs, output_size);
int indices[4];
const int* gdim_r = gdim;
const int* gstride_r = gstride;
const int* qdim_r = qdim;
const int* qstride_r = qstride;
const int* kdim_r = kdim;
const int* kstride_r = kstride;
const int* vdim_r = vdim;
const int* vstride_r = vstride;
const float scale = cmd.info.scaled_dot_product_attention.scale;

MPSGraphExecutable* executable = ccv_nnc_mps_graph_executable_cache(key, indices, ^void (MPSGraph* graph, NSMutableArray<MPSGraphTensor*>* inputTensors, NSMutableArray<MPSGraphShapedType*>* inputShapedTypes, NSMutableArray<MPSGraphTensor*>* resultTensors) {
MPSGraphTensor* mps_input_g;
MPSGraphTensor* mps_g = ccv_nnc_mps_graph_tensor_input(graph, g, gdim_r, gstride_r, &mps_input_g);
[inputTensors addObject:mps_input_g];
MPSGraphShapedType* mps_g_shape = ccv_nnc_mps_graph_tensor_input_shape(g, gdim_r, gstride_r);
[inputShapedTypes addObject:mps_g_shape];

MPSGraphTensor* mps_input_q;
MPSGraphTensor* mps_q = ccv_nnc_mps_graph_tensor_input(graph, q, qdim_r, qstride_r, &mps_input_q);
[inputTensors addObject:mps_input_q];
MPSGraphShapedType* mps_q_shape = ccv_nnc_mps_graph_tensor_input_shape(q, qdim_r, qstride_r);
[inputShapedTypes addObject:mps_q_shape];

MPSGraphTensor* mps_input_k;
MPSGraphTensor* mps_k = ccv_nnc_mps_graph_tensor_input(graph, k, kdim_r, kstride_r, &mps_input_k);
[inputTensors addObject:mps_input_k];
MPSGraphShapedType* mps_k_shape = ccv_nnc_mps_graph_tensor_input_shape(k, kdim_r, kstride_r);
[inputShapedTypes addObject:mps_k_shape];

MPSGraphTensor* mps_input_v;
MPSGraphTensor* mps_v = ccv_nnc_mps_graph_tensor_input(graph, v, vdim_r, vstride_r, &mps_input_v);
[inputTensors addObject:mps_input_v];
MPSGraphShapedType* mps_v_shape = ccv_nnc_mps_graph_tensor_input_shape(v, vdim_r, vstride_r);
[inputShapedTypes addObject:mps_v_shape];

MPSGraphTensor* mps_scale = [graph constantWithScalar:scale dataType:ccv_nnc_mps_datatype(q->info.datatype)];
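// The graph below encodes standard attention backprop. With S = scale * Q K^T,
// P = softmax(S), O = P V, and upstream gradient dO (g):
//   dV = P^T dO
//   dP = dO V^T
//   dS = P .* (dP - rowsum(P .* dP))   (softmax backward)
//   dQ = scale * dS K
//   dK = scale * dS^T Q   (the scale arrives via the pre-scaled Q)
// (Annotation added for this write-up; not a comment in the original source.)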
mps_q = [graph multiplicationWithPrimaryTensor:mps_scale secondaryTensor:[graph transposeTensor:mps_q dimension:1 withDimension:2 name:nil] name:nil];
mps_k = [graph transposeTensor:mps_k dimension:1 withDimension:2 name:nil];
MPSGraphTensor* mps_kt = [graph transposeTensor:mps_k dimension:2 withDimension:3 name:nil];
mps_v = [graph transposeTensor:mps_v dimension:1 withDimension:2 name:nil];
MPSGraphTensor* mps_qk = [graph matrixMultiplicationWithPrimaryTensor:mps_q secondaryTensor:mps_kt name:nil];
MPSGraphTensor* mps_softmax = [graph softMaxWithTensor:mps_qk axis:3 name:nil];
mps_g = [graph transposeTensor:mps_g dimension:1 withDimension:2 name:nil];
MPSGraphTensor* mps_softmaxt = [graph transposeTensor:mps_softmax dimension:2 withDimension:3 name:nil];
MPSGraphTensor* mps_dv = [graph matrixMultiplicationWithPrimaryTensor:mps_softmaxt secondaryTensor:mps_g name:nil];
mps_v = [graph transposeTensor:mps_v dimension:2 withDimension:3 name:nil];
MPSGraphTensor* mps_dsoftmax = [graph matrixMultiplicationWithPrimaryTensor:mps_g secondaryTensor:mps_v name:nil];
MPSGraphTensor* mulTensor = [graph multiplicationWithPrimaryTensor:mps_softmax secondaryTensor:mps_dsoftmax name:nil];
MPSGraphTensor* mulSumTensor = [graph reductionSumWithTensor:mulTensor axis:-1 name:nil];
MPSGraphTensor* gradSubTensor = [graph subtractionWithPrimaryTensor:mps_dsoftmax secondaryTensor:mulSumTensor name:nil];
MPSGraphTensor* mps_dqk = [graph multiplicationWithPrimaryTensor:mps_softmax secondaryTensor:gradSubTensor name:nil];
MPSGraphTensor* mps_dq = [graph multiplicationWithPrimaryTensor:mps_scale secondaryTensor:[graph matrixMultiplicationWithPrimaryTensor:mps_dqk secondaryTensor:mps_k name:nil] name:nil];
mps_dqk = [graph transposeTensor:mps_dqk dimension:2 withDimension:3 name:nil];
MPSGraphTensor* mps_dk = [graph matrixMultiplicationWithPrimaryTensor:mps_dqk secondaryTensor:mps_q name:nil];
mps_dq = [graph transposeTensor:mps_dq dimension:1 withDimension:2 name:nil];
[resultTensors addObject:mps_dq];
mps_dk = [graph transposeTensor:mps_dk dimension:1 withDimension:2 name:nil];
[resultTensors addObject:mps_dk];
mps_dv = [graph transposeTensor:mps_dv dimension:1 withDimension:2 name:nil];
[resultTensors addObject:mps_dv];
});
MPSGraphTensorData* data_g = ccv_nnc_mps_graph_tensor_data(g, gdim, gstride);
MPSGraphTensorData* data_q = ccv_nnc_mps_graph_tensor_data(q, qdim, qstride);
MPSGraphTensorData* data_k = ccv_nnc_mps_graph_tensor_data(k, kdim, kstride);
MPSGraphTensorData* data_v = ccv_nnc_mps_graph_tensor_data(v, vdim, vstride);
MPSGraphTensorData* data[] = {data_g, data_q, data_k, data_v};
ccv_nnc_mps_graph_executable_result(executable, command_buffer, @[data[indices[0]], data[indices[1]], data[indices[2]], data[indices[3]]], (ccv_nnc_tensor_view_t*[]){ dq, dk, dv }, (int*[]){ dqdim, dkdim, dvdim }, (int*[]){ dqstride, dkstride, dvstride }, 3, 0);
ccv_nnc_stream_context_finish_mps_command_buffer(stream_context, command_buffer);
}
}
return CCV_NNC_EXEC_SUCCESS;
}
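For reference, here is a minimal CPU sketch of the softmax-gradient step that both the MPSGraph code above and the MFA backward kernel must realize; it is illustrative only and not part of the commit:

#include <stddef.h>

// For one softmax row, with P = softmax(S) and upstream gradient dP:
//   dS[j] = P[j] * (dP[j] - sum_k P[k] * dP[k])
static void softmax_backward_row(const float* P, const float* dP, float* dS, size_t n)
{
	float dot = 0.0f;
	size_t j;
	for (j = 0; j < n; j++)
		dot += P[j] * dP[j]; // mirrors reductionSumWithTensor over the last axis
	for (j = 0; j < n; j++)
		dS[j] = P[j] * (dP[j] - dot); // mirrors the subtraction / multiplication tensors
}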