Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug Fix] Support threads_per_head < 64 for wavefront size of 64 #6622

Merged
merged 11 commits into from
Nov 4, 2024
10 changes: 9 additions & 1 deletion csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,15 @@ __global__ void apply_rotary_pos_half(T* mixed_query,

#if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64
#define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \
if (threads_per_head == 64) { \
if (threads_per_head == 4) { \
LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \
} else if (threads_per_head == 8) { \
LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \
} else if (threads_per_head == 16) { \
LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \
} else if (threads_per_head == 32) { \
LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \
} else if (threads_per_head == 64) { \
LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \
} else { \
assert(false); \
Expand Down