Commit

remove padding

aciddelgado committed Nov 7, 2023
1 parent 4801fe8 commit 1e1a25c
Showing 4 changed files with 61 additions and 61 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -45,7 +45,7 @@ GroupQueryAttention<T>::GroupQueryAttention(const OpKernelInfo& info)
num_heads_ = static_cast<int>(num_heads);
kv_num_heads_ = static_cast<int>(kv_num_heads);
is_unidirectional_ = true;
-left_padding_ = info.GetAttrOrDefault<int64_t>("left_padding_last_token", 0) == 1;
+// left_padding_ = info.GetAttrOrDefault<int64_t>("left_padding_last_token", 0) == 1;
is_past_bsnh_ = false; // info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

@@ -93,7 +93,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
scale_,
device_prop.maxThreadsPerBlock));
parameters.is_unidirectional = is_unidirectional_;
-parameters.left_padding = left_padding_;
+// parameters.left_padding = left_padding_;
int sequence_length = parameters.sequence_length;

TensorShapeVector output_shape(3);
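Aside: the removed constructor line shows the attribute-to-flag pattern these kernels use. Below is a minimal, hedged sketch of that pattern, with a hypothetical MockKernelInfo standing in for onnxruntime's OpKernelInfo; the template call shape mirrors the removed line, everything else is illustrative.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for onnxruntime::OpKernelInfo, holding integer attributes only.
struct MockKernelInfo {
  std::map<std::string, int64_t> int_attrs;

  template <typename T>
  T GetAttrOrDefault(const std::string& name, const T& default_value) const {
    auto it = int_attrs.find(name);
    return it == int_attrs.end() ? default_value : static_cast<T>(it->second);
  }
};

int main() {
  MockKernelInfo info;
  info.int_attrs["left_padding_last_token"] = 1;

  // Same shape as the removed line: missing attribute -> default 0 -> flag false;
  // attribute present and equal to 1 -> flag true.
  const bool left_padding = info.GetAttrOrDefault<int64_t>("left_padding_last_token", 0) == 1;
  std::cout << std::boolalpha << left_padding << "\n";  // prints: true
  return 0;
}
```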
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -22,7 +22,7 @@ class GroupQueryAttention final : public CudaKernel {
protected:
int num_heads_; // number of attention heads
int kv_num_heads_; // different for k and v for group query attention
-bool left_padding_; // shifts last token to end of buffer
+// bool left_padding_; // shifts last token to end of buffer
bool is_unidirectional_; // causal
bool is_past_bsnh_;
float scale_;
108 changes: 54 additions & 54 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -468,54 +468,54 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i
return CUDA_CALL(cudaGetLastError());
}

-// Kernel to append new kv to kv buffer in place
-template <typename T>
-__global__ void LeftPadLast(const int max_seqlen,
-T* kv_buff,
-const int* seqlens_k) { // refers to kv buff; otherwise bnsh
-const int h = threadIdx.x;
-const int n = blockIdx.x;
-const int b = blockIdx.y;
-
-const int num_heads = gridDim.x;
-const int H = blockDim.x;
-
-const int present_batch_stride = max_seqlen * num_heads * H;
-const int present_row_stride = num_heads * H;
-const int present_head_stride = H;
-
-// kv_buff: BTNH or BNTH with buffered memory for new
-// new_kv: BLNH
-
-const int s = seqlens_k[b];
-
-const int in_offset = b * present_batch_stride + s * present_row_stride + n * present_head_stride + h;
-const int out_offset = b * present_batch_stride + (max_seqlen - 1) * present_row_stride + n * present_head_stride + h;
-kv_buff[out_offset] = kv_buff[in_offset];
-}
-
-// Concat new to kv buffer in place
-template <typename T>
-Status LaunchLeftPadLast(contrib::GroupQueryAttentionParameters& parameters,
-GroupQueryAttentionData<T>& data,
-cudaStream_t stream,
-const int max_threads_per_block) {
-const int batch_size = parameters.batch_size;
-const int sequence_length = parameters.sequence_length;
-const int num_heads = parameters.num_heads;
-const int head_size = parameters.head_size;
-
-// Indicates past sequence_length of each sequence
-const int* seqlens_k = reinterpret_cast<const int*>(data.seqlens_k);
-
-const int H = head_size / 4;
-const dim3 grid(num_heads, batch_size, 1);
-const dim3 block(H, 1, 1);
-LeftPadLast<float2><<<grid, block, 0, stream>>>(sequence_length,
-reinterpret_cast<float2*>(data.output),
-seqlens_k);
-return CUDA_CALL(cudaGetLastError());
-}
+// // Kernel to append new kv to kv buffer in place
+// template <typename T>
+// __global__ void LeftPadLast(const int max_seqlen,
+// T* kv_buff,
+// const int* seqlens_k) { // refers to kv buff; otherwise bnsh
+// const int h = threadIdx.x;
+// const int n = blockIdx.x;
+// const int b = blockIdx.y;
+
+// const int num_heads = gridDim.x;
+// const int H = blockDim.x;
+
+// const int present_batch_stride = max_seqlen * num_heads * H;
+// const int present_row_stride = num_heads * H;
+// const int present_head_stride = H;
+
+// // kv_buff: BTNH or BNTH with buffered memory for new
+// // new_kv: BLNH
+
+// const int s = seqlens_k[b];
+
+// const int in_offset = b * present_batch_stride + s * present_row_stride + n * present_head_stride + h;
+// const int out_offset = b * present_batch_stride + (max_seqlen - 1) * present_row_stride + n * present_head_stride + h;
+// kv_buff[out_offset] = kv_buff[in_offset];
+// }
+
+// // Concat new to kv buffer in place
+// template <typename T>
+// Status LaunchLeftPadLast(contrib::GroupQueryAttentionParameters& parameters,
+// GroupQueryAttentionData<T>& data,
+// cudaStream_t stream,
+// const int max_threads_per_block) {
+// const int batch_size = parameters.batch_size;
+// const int sequence_length = parameters.sequence_length;
+// const int num_heads = parameters.num_heads;
+// const int head_size = parameters.head_size;
+
+// // Indicates past sequence_length of each sequence
+// const int* seqlens_k = reinterpret_cast<const int*>(data.seqlens_k);
+
+// const int H = head_size / 4;
+// const dim3 grid(num_heads, batch_size, 1);
+// const dim3 block(H, 1, 1);
+// LeftPadLast<float2><<<grid, block, 0, stream>>>(sequence_length,
+// reinterpret_cast<float2*>(data.output),
+// seqlens_k);
+// return CUDA_CALL(cudaGetLastError());
+// }
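For readers skimming the now-commented-out kernel above: each CUDA thread handled one (batch b, head n, element h) triple and copied the KV/output row at position seqlens_k[b] into the final row max_seqlen - 1 of a BSNH-laid-out buffer; the launcher ran a (num_heads, batch_size) grid and vectorized with float2, so H = head_size / 4. Below is a hedged, CPU-only re-creation of that indexing for illustration; the function name LeftPadLastCpu and the toy sizes are made up, and this is not the production code path.

```cpp
#include <cstdio>
#include <vector>

// Copies row seqlens_k[b] to row max_seqlen - 1 for every batch and head.
// Layout (BSNH): buff[b][s][n][h] flattened as b*S*N*H + s*N*H + n*H + h.
void LeftPadLastCpu(int max_seqlen, int num_heads, int H,
                    std::vector<float>& buff, const std::vector<int>& seqlens_k) {
  const int batch_size = static_cast<int>(seqlens_k.size());
  const int batch_stride = max_seqlen * num_heads * H;  // elements per batch
  const int row_stride = num_heads * H;                 // elements per token row
  const int head_stride = H;                            // elements per head
  for (int b = 0; b < batch_size; ++b) {
    const int s = seqlens_k[b];  // row holding the last token of this sequence
    for (int n = 0; n < num_heads; ++n) {
      for (int h = 0; h < H; ++h) {
        const int in = b * batch_stride + s * row_stride + n * head_stride + h;
        const int out = b * batch_stride + (max_seqlen - 1) * row_stride + n * head_stride + h;
        buff[out] = buff[in];
      }
    }
  }
}

int main() {
  // Toy sizes (made up): 2 sequences, a 4-row buffer, 1 head, 2 elements per head.
  const int max_seqlen = 4, num_heads = 1, H = 2;
  const std::vector<int> seqlens_k = {1, 2};
  std::vector<float> buff(2 * max_seqlen * num_heads * H);
  for (size_t i = 0; i < buff.size(); ++i) buff[i] = static_cast<float>(i);

  LeftPadLastCpu(max_seqlen, num_heads, H, buff, seqlens_k);

  // Final row of batch 0 now mirrors its row 1, and batch 1 mirrors its row 2.
  std::printf("batch0 last row: %g %g | batch1 last row: %g %g\n",
              buff[6], buff[7], buff[14], buff[15]);  // 2 3 | 12 13
  return 0;
}
```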

////////// Launch Kernels

@@ -614,9 +614,9 @@ Status FlashAttention(
reinterpret_cast<void*>(data.out_accum)));
}

-if (parameters.left_padding && parameters.is_prompt) {
-ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
-}
+// if (parameters.left_padding && parameters.is_prompt) {
+// ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
+// }

DUMP_TENSOR_INIT();
DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
@@ -721,9 +721,9 @@ Status EfficientAttention(
p.has_custom_right_padding = true;
run_memory_efficient_attention(p);

-if (parameters.left_padding && parameters.is_prompt) {
-ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
-}
+// if (parameters.left_padding && parameters.is_prompt) {
+// ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
+// }

DUMP_TENSOR_INIT();
DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
8 changes: 4 additions & 4 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1004,10 +1004,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr("left_padding_last_token",
"Copy last token to last index of buffer. Default is 0; 1 when true.",
AttributeProto::INT,
OPTIONAL_VALUE)
// .Attr("left_padding_last_token",
// "Copy last token to last index of buffer. Default is 0; 1 when true.",
// AttributeProto::INT,
// OPTIONAL_VALUE)
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size)",
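With the schema attribute above removed and the kernel-side read commented out, the GroupQueryAttention contrib op no longer honors left_padding_last_token. For reference, here is a hedged sketch of how a graph node would previously have carried that attribute; it assumes the ONNX protobuf headers are available as onnx/onnx_pb.h (the include path can differ by build), and MakeGqaNodeWithLeftPadding is a made-up helper name.

```cpp
#include "onnx/onnx_pb.h"  // ONNX protobuf message types (assumed include path)

// Builds a GroupQueryAttention node that carries the now-removed attribute.
onnx::NodeProto MakeGqaNodeWithLeftPadding() {
  onnx::NodeProto node;
  node.set_op_type("GroupQueryAttention");
  node.set_domain("com.microsoft");  // Microsoft contrib-op domain
  onnx::AttributeProto* attr = node.add_attribute();
  attr->set_name("left_padding_last_token");
  attr->set_type(onnx::AttributeProto::INT);
  attr->set_i(1);  // 1 requested the last-token copy; the kernel no longer reads this
  return node;
}

int main() {
  const onnx::NodeProto node = MakeGqaNodeWithLeftPadding();
  return node.attribute_size() == 1 ? 0 : 1;  // trivial self-check
}
```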
