Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repkv_backward_kernel2 and repkv_kernel2 (llama3 branch) #771

Open
wants to merge 6 commits into
base: llama3
Choose a base branch
from

Conversation

insop
Copy link

@insop insop commented Sep 28, 2024

Changes

Add repkv_backward_kernel2

  • improve repkv_backward_kernel1 by reducing thread used per @karpathy's suggestion

Also add repkv_kernel2 simiar to backward_kernel2

Here is the test output for repkv_backward_kernel2

# ./repkv_backward 2                                                                                              │
Using kernel 2                                                                                                                                                                    │
Checking block size 32.                                                                                                                                                           │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
Checking block size 64.                                                                                                                                                           │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
Checking block size 128.                                                                                                                                                          │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
Checking block size 256.                                                                                                                                                          │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
Checking block size 512.                                                                                                                                                          │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
Checking block size 1024.                                                                                                                                                         │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
All results match. Starting benchmarks.                                                                                                                                           │
                                                                                                                                                                                  │
block_size   32 time 1.8824 ms                                                                                                                                                    │
block_size   64 time 0.9740 ms                                                                                                                                                    │
block_size  128 time 0.9716 ms                                                                                                                                                    │
block_size  256 time 0.9740 ms                                                                                                                                                    │
block_size  512 time 1.0151 ms                                                                                                                                                    │
block_size 1024 time 1.0725 ms 

Execution time is improved compared to kernel1 time shown below from previous PR (#764)

All results match. Starting benchmarks.

block_size   32 time 3.2461 ms
block_size   64 time 1.7509 ms
block_size  128 time 1.7374 ms
block_size  256 time 1.7441 ms
block_size  512 time 1.8092 ms
block_size 1024 time 2.0443 ms

Here is the test output for repkv_kernel2

# ./repkv 2                                                                                                       │
Using kernel 2                                                                                                                                                                    │
Checking block size 32.                                                                                                                                                           │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
Checking block size 64.                                                                                                                                                           │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
Checking block size 128.                                                                                                                                                          │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
Checking block size 256.                                                                                                                                                          │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
Checking block size 512.                                                                                                                                                          │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
Checking block size 1024.                                                                                                                                                         │
0.680375 0.680375                                                                                                                                                                 │
-0.211234 -0.211234                                                                                                                                                               │
0.566198 0.566198                                                                                                                                                                 │
0.596880 0.596880                                                                                                                                                                 │
0.823295 0.823295                                                                                                                                                                 │
All results match. Starting benchmarks.                                                                                                                                           │
                                                                                                                                                                                  │
block_size   32 time 1.7765 ms                                                                                                                                                    │
block_size   64 time 0.9856 ms                                                                                                                                                    │
block_size  128 time 0.9781 ms                                                                                                                                                    │
block_size  256 time 0.9887 ms                                                                                                                                                    │
block_size  512 time 1.0429 ms                                                                                                                                                    │
block_size 1024 time 1.1434 ms

Execution time is improved compared to kernel1

block_size   32 time 3.6582 ms                                                                                                                                                    │
block_size   64 time 1.5909 ms                                                                                                                                                    │
block_size  128 time 1.5868 ms                                                                                                                                                    │
block_size  256 time 1.5798 ms                                                                                                                                                    │
block_size  512 time 1.6164 ms                                                                                                                                                    │
block_size 1024 time 1.8981 ms  

@insop insop changed the title Add repkv_backward_kernel2, reduced thread by using dinp as index instead of dout Add repkv_backward_kernel2, reduced thread by using dinp as index instead of dout Sep 28, 2024
@insop
Copy link
Author

insop commented Sep 28, 2024

@karpathy , PTAL.

@insop insop changed the title Add repkv_backward_kernel2, reduced thread by using dinp as index instead of dout Add repkv_backward_kernel2 and repkv_kernel2 Sep 28, 2024
@insop insop changed the title Add repkv_backward_kernel2 and repkv_kernel2 Add repkv_backward_kernel2 and repkv_kernel2 (llama3 branch) Oct 1, 2024
@insop
Copy link
Author

insop commented Oct 2, 2024

@karpathy,

Please let me know if you have any feedback.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant