Add GQA support for ROCm (#21032)
cloudhan authored Jul 3, 2024
1 parent 4932e04 commit f39ee14
Showing 14 changed files with 810 additions and 42 deletions.
2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
@@ -241,7 +241,7 @@ option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF)

# composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set
option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON)
-option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON)
+cmake_dependent_option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON "onnxruntime_USE_COMPOSABLE_KERNEL" OFF)
option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF)
option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF)
option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF)
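
Note on the change above: cmake_dependent_option ties the ck_tile flag to its parent option. onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE still defaults to ON, but only while onnxruntime_USE_COMPOSABLE_KERNEL is ON; when composable kernel is disabled, ck_tile is now forced OFF instead of being left set.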
1 change: 0 additions & 1 deletion cmake/onnxruntime_rocm_hipify.cmake
@@ -88,7 +88,6 @@ set(contrib_ops_excluded_files
"cuda_contrib_kernels.h"
"inverse.cc"
"fused_conv.cc"
"bert/group_query_attention_helper.h"
"bert/group_query_attention.h"
"bert/group_query_attention.cc"
"bert/group_query_attention_impl.h"
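
Removing group_query_attention_helper.h from the exclusion list means hipify now converts that header for the ROCm EP instead of skipping it, which lets the ROCm GQA implementation reuse the CUDA helper logic.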
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -176,6 +176,13 @@ Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const T* qkv_buffer,
T* present);

+template <typename T>
+Status LaunchStridedCopy(
+    cudaStream_t stream,
+    const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset,  // coord (b,n,s,h)
+    T* out, longlong4 out_strides, const int* out_seqlens_offset,  // coord (b,n,s,h)
+    int max_threads_per_block);

template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
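
For reference, the semantics of the offset-taking overload declared above, as a host-side sketch (illustrative code, not from this commit; the loops mirror the CUDA kernels in attention_strided_copy.cu):

#include <cuda_runtime.h>  // int4, longlong4

// Reference semantics of LaunchStridedCopy with per-batch sequence offsets:
// a (b,n,s,h) copy where each side may start at its own offset along s.
template <typename T>
void StridedCopyReference(const T* in, int4 in_shape, longlong4 in_strides,
                          const int* in_seqlens_offset,  // nullptr => offset 0
                          T* out, longlong4 out_strides,
                          const int* out_seqlens_offset) {
  for (int b = 0; b < in_shape.x; ++b) {
    const int si = in_seqlens_offset ? in_seqlens_offset[b] : 0;
    const int so = out_seqlens_offset ? out_seqlens_offset[b] : 0;
    for (int n = 0; n < in_shape.y; ++n)
      for (int s = 0; s < in_shape.z; ++s)
        for (int h = 0; h < in_shape.w; ++h)
          out[b * out_strides.x + n * out_strides.y +
              (s + so) * out_strides.z + h * out_strides.w] =
              in[b * in_strides.x + n * in_strides.y +
                 (s + si) * in_strides.z + h * in_strides.w];
  }
}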
63 changes: 45 additions & 18 deletions onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu
@@ -12,23 +12,27 @@ namespace cuda {

template <typename T>
__global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h)
-T* out, longlong4 out_strides  // coord (b,n,s,h)
-) {
+T* out, longlong4 out_strides,  // coord (b,n,s,h)
+const int32_t* in_seqlens_offset, const int32_t* out_seqlens_offset) {
const int h = threadIdx.x;
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;

+const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b];
+const int s_offset_o = out_seqlens_offset == nullptr ? 0 : out_seqlens_offset[b];

if (h < H) {
-const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w;
-const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w;
+const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w;
+const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w;
out[out_offset] = in[in_offset];
}
}

template <typename T>
__global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h)
-T* out, longlong4 out_strides  // coord (b,n,s,h)
-) {
+T* out, longlong4 out_strides,  // coord (b,n,s,h)
+const int* in_seqlens_offset, const int* out_seqlens_offset) {
// Use when (H*)*num_heads > 1024
int h = threadIdx.x;
const int n = threadIdx.y;
@@ -37,9 +41,12 @@ __global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides,

const int h_step = blockDim.x;

+const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b];
+const int s_offset_o = out_seqlens_offset == nullptr ? 0 : out_seqlens_offset[b];

while (h < H) {
-const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w;
-const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w;
+const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w;
+const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w;
out[out_offset] = in[in_offset];
h += h_step;
}
@@ -77,10 +84,11 @@ template <int NumBytes>
using ToBytes = typename ToByteType<NumBytes>::T;

template <typename T>
-Status LaunchStridedCopy(cudaStream_t stream,
-                         const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
-                         T* out, longlong4 out_strides,  // coord (b,n,s,h)
-                         int max_threads_per_block) {
+Status LaunchStridedCopy(
+    cudaStream_t stream,
+    const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset,  // coord (b,n,s,h)
+    T* out, longlong4 out_strides, const int* out_seqlens_offset,  // coord (b,n,s,h)
+    int max_threads_per_block) {
int batch_size = in_shape.x;
int num_heads = in_shape.y;
int sequence_length = in_shape.z;
@@ -102,11 +110,13 @@ Status LaunchStridedCopy(cudaStream_t stream,
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
StridedCopy<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
-                                               reinterpret_cast<Bytes*>(out), out_strides);
+                                               reinterpret_cast<Bytes*>(out), out_strides,
+                                               in_seqlens_offset, out_seqlens_offset);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
-                                                    reinterpret_cast<Bytes*>(out), out_strides);
+                                                    reinterpret_cast<Bytes*>(out), out_strides,
+                                                    in_seqlens_offset, out_seqlens_offset);
}
} else if (0 == (head_size % 2)) { // pack 2 element together
using Bytes = ToBytes<sizeof(T) * 2>;
@@ -120,27 +130,44 @@ Status LaunchStridedCopy(cudaStream_t stream,
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
StridedCopy<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
-                                               reinterpret_cast<Bytes*>(out), out_strides);
+                                               reinterpret_cast<Bytes*>(out), out_strides,
+                                               in_seqlens_offset, out_seqlens_offset);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
-                                                    reinterpret_cast<Bytes*>(out), out_strides);
+                                                    reinterpret_cast<Bytes*>(out), out_strides,
+                                                    in_seqlens_offset, out_seqlens_offset);
}
} else {
using Bytes = ToBytes<sizeof(T)>;
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
StridedCopy<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), head_size, in_strides,
-                                               reinterpret_cast<Bytes*>(out), out_strides);
+                                               reinterpret_cast<Bytes*>(out), out_strides,
+                                               in_seqlens_offset, out_seqlens_offset);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), head_size, in_strides,
-                                                    reinterpret_cast<Bytes*>(out), out_strides);
+                                                    reinterpret_cast<Bytes*>(out), out_strides,
+                                                    in_seqlens_offset, out_seqlens_offset);
}
}
return CUDA_CALL(cudaGetLastError());
}

+template <typename T>
+Status LaunchStridedCopy(cudaStream_t stream,
+                         const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
+                         T* out, longlong4 out_strides,  // coord (b,n,s,h)
+                         int max_threads_per_block) {
+  const int* in_seqlens_offset = nullptr;
+  const int* out_seqlens_offset = nullptr;
+  return LaunchStridedCopy<T>(
+      stream, in, in_shape, in_strides, in_seqlens_offset,
+      out, out_strides, out_seqlens_offset,
+      max_threads_per_block);
+}

template Status LaunchStridedCopy<float>(
cudaStream_t stream,
const float* in, int4 in_shape, longlong4 in_strides,
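
A note on the Bytes/ToBytes dispatch above: when head_size is divisible by 4 (or 2), the launcher shrinks H and reinterprets both buffers as a type 4x (or 2x) wider, so each thread moves several elements with one vectorized load/store. A minimal sketch of the mapping, assuming ToByteType (defined in the collapsed part of this file) behaves like a size-to-POD-type table:

#include <cstdint>
#include <cuda_runtime.h>  // uint2 / uint4 vector types

// Sketch only -- my stand-in for ToByteType, not the commit's definition.
template <int NumBytes> struct ToByteTypeSketch;
template <> struct ToByteTypeSketch<2> { using T = uint16_t; };   // 1 x half
template <> struct ToByteTypeSketch<4> { using T = uint32_t; };   // 2 x half, 1 x float
template <> struct ToByteTypeSketch<8> { using T = uint2; };      // 4 x half, 2 x float
template <> struct ToByteTypeSketch<16> { using T = uint4; };     // 4 x float

// E.g. for T = half with head_size % 4 == 0:
//   using Bytes = ToByteTypeSketch<sizeof(half) * 4>::T;  // uint2, 8 bytes
// with H = head_size / 4 and the strides rescaled to Bytes-sized units.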
@@ -577,7 +577,7 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp
}

// Kernel to convert seqlens_k to position_ids
-__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen,
+__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen,
const int batch_size) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
int b = tid / seqlen;
@@ -592,15 +592,15 @@ __global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids,
}

// Kernel to convert seqlens_k to position_ids
-__global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, const int batch_size) {
+__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < batch_size) {
position_ids[tid] = seqlens_k[tid];
}
}

// Convert seqlens_k to position_ids
-Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k,
+Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k,
int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) {
const int seqlen = parameters.sequence_length;
const int batch_size = parameters.batch_size;
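
The hunks in this file are const-correctness fixes only: seqlens_k is read and never written, so the kernels and LaunchSeqlensToPosIds now take const int32_t*, letting callers (including the new ROCm path) pass const-qualified buffers without a cast.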
59 changes: 41 additions & 18 deletions onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
@@ -25,8 +25,9 @@ __global__ void RotaryEmbeddingBSNH(T* output,  // BxSxNxH
const int64_t* position_ids, // (1) or BxS
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int position_ids_format,
-const bool interleaved, const int batch_stride, const int seq_stride,
-const int head_stride) {
+const bool interleaved,
+int4 in_strides, int4 out_strides  // strides in bnsh coord, h is always contiguous
+) {
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
// Use .x in innermost loop to access global memory efficiently

@@ -40,10 +41,8 @@ __global__ void RotaryEmbeddingBSNH(T* output,  // BxSxNxH
return;
}

-const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
-
-const T* input_data = input + block_offset;
-T* output_data = output + block_offset;
+const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y;
+T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y;

if (i >= rotary_embedding_dim) {
output_data[i] = input_data[i];
@@ -77,34 +76,58 @@ template <typename T>
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
const T* cos_cache, const T* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
-const int rotary_embedding_dim, const int /*max_sequence_length*/,
+const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block, const bool is_input_bnsh_format) {
+  int4 in_strides;
+  int4 out_strides;
+  if (is_input_bnsh_format) {
+    int in_head_stride = sequence_length * head_size;
+    int out_head_stride = sequence_length * head_size;
+    in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1};
+    out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1};
+  } else {
+    int in_head_stride = head_size;
+    int out_head_stride = head_size;
+    in_strides = int4{sequence_length * num_heads * in_head_stride, in_head_stride, num_heads * in_head_stride, 1};
+    out_strides = int4{sequence_length * num_heads * out_head_stride, out_head_stride, num_heads * out_head_stride, 1};
+  }
+  return LaunchRotaryEmbeddingKernel<T>(
+      stream, output, input, position_ids,
+      cos_cache, sin_cache, batch_size,
+      sequence_length, num_heads, head_size,
+      rotary_embedding_dim, max_sequence_length,
+      position_ids_format, interleaved,
+      max_threads_per_block,
+      in_strides, out_strides);
+}

+template <typename T>
+Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
+                                   const T* cos_cache, const T* sin_cache, const int batch_size,
+                                   const int sequence_length, const int num_heads, const int head_size,
+                                   const int rotary_embedding_dim, const int /*max_sequence_length*/,
+                                   const int position_ids_format, const bool interleaved,
+                                   const int max_threads_per_block,
+                                   int4 in_strides, int4 out_strides  // strides in bnsh coord
+) {
// Note: Current implementation assumes head_size <= max_threads_per_block
// because head_size is currently large for LLaMA-2. For smaller head_size
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
// instead. This will require kernel changes to support.
ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block");
+  // strides in canonical bnsh coord, h is always contiguous (dim_stride == 1)
+  ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must contiguous");

int tpb = (head_size + 31) / 32 * 32;

const dim3 block(tpb);
const dim3 grid(sequence_length, batch_size, num_heads);

-  // Default input tensor shape is [batch, seq, hidden_size]
-  int head_stride = head_size;
-  int seq_stride = num_heads * head_stride;
-  int batch_stride = sequence_length * seq_stride;
-  if (is_input_bnsh_format) {
-    seq_stride = head_size;
-    head_stride = sequence_length * seq_stride;
-    batch_stride = num_heads * head_stride;
-  }
-
-  assert(head_size <= max_threads_per_block);
RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
num_heads, head_size, rotary_embedding_dim, position_ids_format,
-                                                interleaved, batch_stride, seq_stride, head_stride);
+                                                interleaved, in_strides, out_strides);

return CUDA_CALL(cudaGetLastError());
}
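
The refactor above replaces the three scalar strides with int4 tuples in canonical (b, n, s, h) order, so the kernel can consume any layout whose head dim is contiguous. Restating the stride arithmetic from the hunk as a sketch (host helpers of my own, not commit code):

#include <cuda_runtime.h>  // int4

// BSNH layout [B, S, N, H]: index(b, n, s, h) = b*S*N*H + s*N*H + n*H + h
int4 BsnhStrides(int S, int N, int H) { return int4{S * N * H, H, N * H, 1}; }

// BNSH layout [B, N, S, H]: index(b, n, s, h) = b*N*S*H + n*S*H + s*H + h
int4 BnshStrides(int S, int N, int H) { return int4{N * S * H, S * H, H, 1}; }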
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
@@ -28,6 +28,26 @@ Status LaunchRotaryEmbeddingKernel(
const int max_threads_per_block,
const bool is_input_bnsh_format);

+template <typename T>
+Status LaunchRotaryEmbeddingKernel(
+    cudaStream_t stream,
+    T* output,
+    const T* input,
+    const int64_t* position_ids,
+    const T* cos_cache,
+    const T* sin_cache,
+    const int batch_size,
+    const int sequence_length,
+    const int num_heads,
+    const int head_size,
+    const int rotary_embedding_dim,
+    const int max_sequence_length,
+    const int position_ids_format,
+    const bool interleaved,
+    const int max_threads_per_block,
+    int4 in_strides,
+    int4 out_strides);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
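
A hypothetical call site for the stride-based overload (every name below is illustrative, not from this commit), rotating a BNSH-layout query buffer into a separate output buffer:

// Illustrative only -- assumes the surrounding buffers/variables exist.
int4 q_strides = int4{num_heads * sequence_length * head_size,  // b stride
                      sequence_length * head_size,              // n stride
                      head_size,                                // s stride
                      1};                                       // h stride (contiguous)
ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel<half>(
    stream, q_out, q_in, position_ids, cos_cache, sin_cache,
    batch_size, sequence_length, num_heads, head_size,
    rotary_embedding_dim, max_sequence_length,
    position_ids_format, interleaved, max_threads_per_block,
    /*in_strides*/ q_strides, /*out_strides*/ q_strides));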
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.h
@@ -169,6 +169,13 @@ Status ClassifyAttentionMode(AttentionType type,
const std::vector<const Tensor*>& past,
const std::vector<Tensor*>& present);

+template <typename T>
+Status LaunchStridedCopy(
+    hipStream_t stream,
+    const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset,  // coord (b,n,s,h)
+    T* out, longlong4 out_strides, const int* out_seqlens_offset,  // coord (b,n,s,h)
+    int max_threads_per_block);

template <typename T>
Status LaunchStridedCopy(hipStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)