Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DecoderMaskedMultiHeadAttention CPU kernel. #22292

Merged
merged 14 commits into from
Oct 12, 2024
12 changes: 6 additions & 6 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -1167,17 +1167,17 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>mask_index</tt> (optional) : M</dt>
<dd>Mask values of shape (batch_size, total_sequence_length) or (batch_size, kv_sequence_length)</dd>
<dt><tt>attention_bias</tt> (optional) : T</dt>
<dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)</dd>
<dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) for self attentionWhen past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size). The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
<dd>past state for value with shape (batch_size, num_heads, past_sequence_length, head_size) for self attentionWhen past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size). </dd>
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).Cross Attention doesn't need this input.</dd>
<dt><tt>beam_width</tt> (optional) : M</dt>
<dd>The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.</dd>
<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
<dt><tt>cache_indirection</tt> (optional) : M</dt>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration</dd>
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
</dl>
Expand All @@ -1192,7 +1192,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>present_value</tt> (optional) : T</dt>
<dd>present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dt><tt>qk</tt> (optional) : V</dt>
<dd>normalized Q * K, of shape (batch_size, num_heads, 1, head_size). </dd>
<dd>normalized Q * K, of shape (batch_size, num_heads, 1, total_sequence_length). </dd>
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
</dl>

#### Type Constraints
Expand Down Expand Up @@ -1261,9 +1261,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>past_sequence_length</tt> : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).</dd>
<dt><tt>beam_width</tt> (optional) : M</dt>
<dd>The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.</dd>
<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
<dt><tt>cache_indirection</tt> (optional) : M</dt>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration</dd>
</dl>

#### Outputs
Expand Down
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ Do not modify directly.*
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|CropAndResize|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *in* crop_size:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int32)|
|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**V**|1+|**T** = tensor(float)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(float)|
|DynamicQuantizeLSTM|*in* X:**T**<br> *in* W:**T2**<br> *in* R:**T2**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* W_scale:**T**<br> *in* W_zero_point:**T2**<br> *in* R_scale:**T**<br> *in* R_zero_point:**T2**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(float)<br/> **T1** = tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**<br> *in* B:**T2**<br> *in* b_scale:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
Expand Down
39 changes: 39 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,45 @@ struct AttentionParameters {
AttentionQkvFormat qkv_format;
};

struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
int beam_width = 1;

// Only NeoX style rotary embedding is supported
int rotary_embedding_dim = 0;
int t_step = 0;

// Whether to use multihead attention(excludes matmul and bias)
bool is_mha = false;
bool is_cross_attention = false;
bool is_packed_qkv = false;

// Useful to better use global memory bandwidth on certain CUDA architectures.
// Turned off by default for now until we fully understand performance implications
// for all types of workloads.
// Can be turned on by appropriate environment variable (see attention_common.h).
bool kv_data_in_flight = false;

void* q = nullptr;
void* q_bias = nullptr;

void* k = nullptr;
void* k_bias = nullptr;

void* v = nullptr;
void* v_bias = nullptr;

void* attention_bias = nullptr;

void* k_cache = nullptr;
void* v_cache = nullptr;

void* out = nullptr;
void* out_qk = nullptr;

const int32_t* cache_indir = nullptr;
const int32_t* mask = nullptr; // [B, total_sequence_length]
};

// Parameters deduced from node attributes and inputs/outputs.
struct PackedAttentionParameters {
int batch_size;
Expand Down
Loading
Loading