
[Bug Fix] Support threads_per_head < 64 for wavefront size of 64 #6622

Merged
merged 11 commits into from
Nov 4, 2024

Conversation

@jagadish-amd (Contributor) commented Oct 11, 2024

When launching the apply_rotary_pos_half kernel, only a threads_per_head of 64 is supported for a wavefront size of 64.
This change adds support for threads_per_head values smaller than 64, such as 4, 8, and 16.

Fixes the issue introduced in #5402

It also removes the ROCm and wavefront size check condition.

Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
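The layout the fix implies can be sketched in plain Python (all names here are hypothetical, for illustration only, not the kernel's actual variables): with a 64-lane wavefront and threads_per_head < 64, one wavefront serves several heads at once, so each lane must derive both a head index and an element index from its lane id.

```python
# Hypothetical sketch of the lane-to-element mapping implied by the fix:
# a 64-lane wavefront with threads_per_head < 64 covers several heads,
# so each lane computes which head and which rotary element it owns.

WAVEFRONT_SIZE = 64  # AMD Instinct wavefront width (assumption for this sketch)

def lane_mapping(lane_id: int, threads_per_head: int):
    """Return (head_within_wave, element_within_head) for one lane."""
    assert WAVEFRONT_SIZE % threads_per_head == 0
    head = lane_id // threads_per_head   # which head this lane serves
    elem = lane_id % threads_per_head    # which rotary element within the head
    return head, elem
```

With threads_per_head = 16, for example, one wavefront covers 4 heads: lanes 0-15 map to head 0, lanes 16-31 to head 1, and so on, which is why the 64-lane-only assumption had to be relaxed.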
@jagadish-amd jagadish-amd marked this pull request as ready for review October 12, 2024 03:27
@jagadish-amd jagadish-amd requested a review from awan-10 as a code owner October 12, 2024 03:27
@jagadish-amd (Contributor Author) commented:
ping @jithunnair-amd @jeffdaily @loadams

@jagadish-amd (Contributor Author) commented:
@loadams any comments on this PR?

@tjruwase tjruwase requested review from tjruwase and removed request for awan-10 October 18, 2024 15:45
@jithunnair-amd (Contributor) left a comment:

LGTM, but let's add a unit test to ensure this functionality can be tested on ROCm (and CUDA)

@loadams loadams self-assigned this Oct 28, 2024
@loadams (Collaborator) commented Oct 30, 2024

> LGTM, but let's add a unit test to ensure this functionality can be tested on ROCm (and CUDA)

@jagadish-amd - thoughts on adding unit tests for this?

@jagadish-amd (Contributor Author) commented:
> LGTM, but let's add a unit test to ensure this functionality can be tested on ROCm (and CUDA)
>
> @jagadish-amd - thoughts on adding unit tests for this?

I will add the unit tests. Thanks

@jagadish-amd jagadish-amd requested a review from tohtana as a code owner November 4, 2024 07:48
@jagadish-amd (Contributor Author) commented:

> LGTM, but let's add a unit test to ensure this functionality can be tested on ROCm (and CUDA)
>
> @jagadish-amd - thoughts on adding unit tests for this?
>
> I will add the unit tests. Thanks

@loadams I have added a test case for the threads_per_head / warp size alignment issue.
Unfortunately, I lost access to the AI model / node on which the "Assertion `false' failed" error had triggered with warp_size 64. Hence the values in the test cases are assumed here, but they still exercise the intended fix. I will add the exact values in the future.
I noticed that the files in unit/ops/transformer/inference were refactored. I have used the InferenceBuilder().load() approach to test the apply_rotary_pos_emb function. If this is not right, please let me know; we can merge the PR without the newly added test, and folks can add test cases later, since this change affects only the warp_size 64 case / AMD Instinct devices.

These are the results:
On warp_size = 32, test cases pass regardless of the changes in apply_rotary_pos_emb.cu:
==================================== 4 passed, 2 warnings in 25.24s ====================================
On warp_size = 64, with the fix:
==================================== 4 passed, 2 warnings in 5.58s =====================================
On warp_size = 64, without the fix:
python: /opt/conda/envs/py_3.10/lib/python3.10/site-packages/deepspeed/ops/csrc/transformer/inference/csrc/apply_rotary_pos_emb.hip:169: void launch_apply_rotary_pos_emb(T *, T *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, float, hipStream_t, int) [T = float]: Assertion `false' failed.
Fatal Python error: Aborted

The test run is aborted (as expected) due to the error in the kernel. I am not sure whether there is a better way to handle this?
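A unit test like the one described would typically compare the kernel's output against a slow host-side reference. The sketch below is one such reference for rotary position embedding, written in plain Python; the function name, the pairing of element i with element i + dim/2, and the default base of 10000 are assumptions for illustration, not the DeepSpeed kernel's exact conventions.

```python
import math

def rotary_reference(x, base=10000.0):
    """Host-side reference rotary embedding on x[seq][dim] (dim even).

    Pairs element i with element i + dim//2 and rotates each pair by an
    angle that depends on the sequence position and the pair index.
    """
    seq_len, dim = len(x), len(x[0])
    half = dim // 2
    out = [[0.0] * dim for _ in range(seq_len)]
    for pos in range(seq_len):
        for i in range(half):
            theta = pos / (base ** (2 * i / dim))
            c, s = math.cos(theta), math.sin(theta)
            a, b = x[pos][i], x[pos][i + half]
            out[pos][i] = a * c - b * s
            out[pos][i + half] = a * s + b * c
    return out
```

At position 0 every rotation angle is zero, so the output equals the input; that property alone makes a cheap sanity check for a kernel test, independent of the exact head/warp geometry.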

@loadams (Collaborator) commented Nov 4, 2024


@jagadish-amd - this should be fine for now. I believe the only remaining item for this PR is the CLA agreement; you should just need to reply to it accepting, with the company listed as AMD.

@jagadish-amd (Contributor Author) commented:
> @jagadish-amd please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.
>
> @microsoft-github-policy-service agree [company="{your company}"]
>
> Options:
>
> • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
> @microsoft-github-policy-service agree
> • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term "You" includes me and my employer.
> @microsoft-github-policy-service agree company="Microsoft"
>
> Contributor License Agreement

@microsoft-github-policy-service agree company="AMD"

@loadams loadams added this pull request to the merge queue Nov 4, 2024
Merged via the queue into deepspeedai:master with commit 2b41d62 Nov 4, 2024
11 checks passed