Add CUDA implementation for attn_probs
amancini-N committed Dec 12, 2024
1 parent 3147d51 commit 239df8b
Showing 7 changed files with 150 additions and 59 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -161,7 +161,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
past_value == nullptr &&
present_k == nullptr &&
present_v == nullptr &&
- attn_probs == nullptr && // TODO: can we support it?
+ attn_probs == nullptr &&  // TODO: can we support it?

l2_cache_size_ > 0) {
MlasFlashAttentionThreadedArgs args;
args.batch_size = batch_size;
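The MLAS flash-attention fast path above is gated on attn_probs == nullptr: a flash-attention kernel never materializes the full softmax matrix, so the probabilities output can only be produced by the unfused path. A minimal, self-contained sketch of that gating pattern (the struct and helper below are hypothetical illustrations, not ORT code):

    // Hypothetical condensation of the eligibility check shown above.
    struct FlashAttentionCandidate {
      const void* past_value = nullptr;
      const void* present_k = nullptr;
      const void* present_v = nullptr;
      const void* attn_probs = nullptr;  // optional probabilities output
      int l2_cache_size = 0;
    };

    bool EligibleForMlasFlashAttention(const FlashAttentionCandidate& c) {
      return c.past_value == nullptr &&
             c.present_k == nullptr &&
             c.present_v == nullptr &&
             c.attn_probs == nullptr &&  // flash attention never writes the probabilities
             c.l2_cache_size > 0;
    }
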
27 changes: 17 additions & 10 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -590,15 +590,22 @@ Status UnfusedAttention(

DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length);

-   constexpr size_t element_size = sizeof(T);
-   const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
-                                                sequence_length, total_sequence_length);
-   T* scratch2 = data.scratch + (bytes / element_size);
+   T* softmax_storage;
+   if (data.attn_probs == nullptr) {
+     constexpr size_t element_size = sizeof(T);
+     const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
+                                                  sequence_length, total_sequence_length);
+     T* scratch2 = data.scratch + (bytes / element_size);
+     softmax_storage = scratch2;
+   }
+   else {
+     softmax_storage = data.attn_probs;
+   }

const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0;
const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1;

-   // Apply softmax and store result R to scratch2: BxNxSxT
+   // Apply softmax and store result R to softmax_storage: BxNxSxT
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
const int mask_dimension = static_cast<int>(mask_index_dims.size());

@@ -612,7 +619,7 @@
ComputeSoftmaxWithRawMask<T>(
ort_stream, total_sequence_length, sequence_length, batch_size, num_heads,
mask_index, nullptr, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
- data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension,
+ data.scratch, softmax_storage, parameters.is_unidirectional, scale, mask_dimension,
parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
parameters.mask_filter_value));
} else if (nullptr != mask_index) { // 1d mask index
@@ -622,24 +629,24 @@
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D<T>(
stream, total_sequence_length, sequence_length, batch_size, num_heads,
mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
- data.scratch, scratch2, parameters.is_unidirectional));
+ data.scratch, softmax_storage, parameters.is_unidirectional));
} else { // no mask
ORT_RETURN_IF_ERROR(
ComputeSoftmax<T>(
stream, total_sequence_length, sequence_length, batch_size, num_heads,
data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
- data.scratch, scratch2, parameters.is_unidirectional));
+ data.scratch, softmax_storage, parameters.is_unidirectional));
}

- DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
+ DUMP_TENSOR_D("Softmax", softmax_storage, batch_size, num_heads, sequence_length, total_sequence_length);

// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
T* temp_output = data.q;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
v_head_size, sequence_length, total_sequence_length,
&one, data.v, v_head_size, present_size_per_batch_v,
- scratch2, total_sequence_length, sequence_length * total_sequence_length,
+ softmax_storage, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
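This hunk is the core of the change: when the caller provides an attn_probs buffer, the softmax kernels write the BxNxSxT probability matrix directly into it instead of into the second half of the QK scratch workspace, and the following R*V GEMM reads from whichever buffer was selected, so no extra device copy is needed. A stripped-down, self-contained sketch of the selection (SoftmaxBuffers and SelectSoftmaxStorage below are simplified stand-ins for ORT's AttentionData and the inlined logic, not real API):

    #include <cstddef>

    // Simplified stand-in for the relevant AttentionData members.
    template <typename T>
    struct SoftmaxBuffers {
      T* scratch = nullptr;     // QK scores occupy the first scratch_bytes of the workspace
      T* attn_probs = nullptr;  // optional graph output; null when not requested
    };

    // Returns where the ComputeSoftmax* kernels should write: the caller-provided
    // output when attn_probs is requested, otherwise the region just past the QK
    // scores in the scratch workspace (the old scratch2).
    template <typename T>
    T* SelectSoftmaxStorage(const SoftmaxBuffers<T>& buf, size_t scratch_bytes) {
      if (buf.attn_probs != nullptr) {
        return buf.attn_probs;
      }
      return buf.scratch + scratch_bytes / sizeof(T);
    }
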
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -81,6 +81,7 @@ struct AttentionData {
T* present = nullptr;
T* present_key = nullptr;
T* present_value = nullptr;
+ T* attn_probs = nullptr;

void* fused_runner = nullptr;
const void* fused_cross_attention_kernel = nullptr;
15 changes: 13 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -113,6 +113,13 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
Tensor* output = context->Output(0, output_shape);

+ TensorShapeVector attn_probs_shape(4);
+ attn_probs_shape[0] = static_cast<int64_t>(parameters.batch_size);
+ attn_probs_shape[1] = static_cast<int64_t>(parameters.num_heads);
+ attn_probs_shape[2] = static_cast<int64_t>(sequence_length);
+ attn_probs_shape[3] = static_cast<int64_t>(parameters.total_sequence_length);
+ Tensor* attn_probs = context->Output(3, attn_probs_shape);

std::vector<int64_t> present_dims{
parameters.batch_size, parameters.num_heads, parameters.total_sequence_length, parameters.head_size};
TensorShape present_shape(present_dims);
@@ -172,6 +179,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.past_sequence_length > 0 &&
nullptr == attention_bias &&
nullptr == key_padding_mask &&
+ nullptr == attn_probs && // TODO: support attn_probs

parameters.head_size == parameters.v_head_size &&
onnxruntime::lean::is_supported(device_prop,
parameters.head_size,
@@ -216,6 +224,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
!disable_flash_attention_ &&
nullptr == attention_bias &&
nullptr == key_padding_mask &&
+ nullptr == attn_probs && // TODO: support attn_probs

parameters.head_size == parameters.v_head_size &&
onnxruntime::flash::is_supported(device_prop,
parameters.head_size,
@@ -280,7 +289,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
!is_unidirectional_ &&
nullptr == key_padding_mask &&
nullptr == attention_bias &&
- nullptr == past_key && nullptr == present_key &&
+ nullptr == past_key && nullptr == present_key && nullptr == attn_probs &&
(parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) &&
parameters.hidden_size == parameters.v_hidden_size &&
has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length);
@@ -305,7 +314,7 @@
!is_unidirectional_ &&
nullptr == attention_bias &&
(parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
- nullptr == past_key && nullptr == present_key &&
+ nullptr == past_key && nullptr == present_key && nullptr == attn_probs &&
is_mask_none_or_1d_k_len &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length && // self attention only for fused runner
@@ -339,6 +348,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
kernel_type == AttentionKernelType::AttentionKernel_Default &&
!disable_memory_efficient_attention_ &&
is_long_sequence &&
+ nullptr == attn_probs && // TODO: support attn_probs

// Check whether the attention bias alignment is good for memory efficient attention.
(attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) &&
(nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
@@ -369,6 +379,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
+ data.attn_probs = (nullptr == attn_probs) ? nullptr : reinterpret_cast<CudaT*>(attn_probs->MutableData<T>());
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.fused_cross_attention_kernel = fused_cross_attention_kernel;
data.use_flash_attention = use_flash_attention;
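In ComputeInternal the probabilities become a fourth, optional output with shape [batch_size, num_heads, sequence_length, total_sequence_length]. When the graph does not consume output 3, context->Output(3, ...) returns nullptr, data.attn_probs stays null, and the lean/flash/fused/memory-efficient paths remain eligible; when it is consumed, those paths are ruled out and the unfused kernel writes the softmax result straight into the output. A compilable sketch of that wiring (the stand-in types and the PrepareAttnProbs helper are hypothetical, not ORT API):

    #include <cstdint>
    #include <vector>

    // Minimal stand-ins so the control flow compiles on its own; not ORT types.
    struct Tensor { float* data = nullptr; };
    struct KernelContext {
      // Mirrors the behavior relied on above: an optional output that is not
      // connected in the graph comes back as nullptr. Stubbed to always decline.
      Tensor* Output(int /*index*/, const std::vector<int64_t>& /*shape*/) { return nullptr; }
    };
    struct AttentionBuffers { float* attn_probs = nullptr; };

    // Request output 3 with shape [B, N, S, T]; when it is consumed, the fused
    // fast paths are skipped and the unfused path fills this tensor directly.
    bool PrepareAttnProbs(KernelContext& ctx, AttentionBuffers& data,
                          int64_t batch, int64_t num_heads,
                          int64_t seq_len, int64_t total_seq_len) {
      Tensor* attn_probs = ctx.Output(3, {batch, num_heads, seq_len, total_seq_len});
      data.attn_probs = (attn_probs == nullptr) ? nullptr : attn_probs->data;
      return data.attn_probs != nullptr;  // true => only the unfused kernel is eligible
    }
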
