
test: refactor flash_attn tests to use parameterized #20913

Merged: tianleiwu merged 2 commits into main from guangyunhan/refactor-flash-attn-test on Jun 11, 2024

Conversation

@cloudhan (Contributor) commented on Jun 4, 2024

Use `parameterized` to decompose the huge test case. This will make it possible to add ROCm support later.
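
For illustration, here is a minimal sketch (not the PR's actual diff) of how `parameterized.expand` can split one monolithic test into individually named cases. The `Config` fields are inferred from the generated test names shown in the pytest output below, and `mha_configs()` is a hypothetical stand-in for the real enumeration:

    # Minimal sketch, assuming the `parameterized` package is installed.
    import unittest
    from collections import namedtuple

    from parameterized import parameterized

    # Field names inferred from the generated test names below.
    Config = namedtuple(
        "Config",
        "batch_size sequence_length kv_sequence_length past_sequence_length "
        "num_heads kv_num_heads head_size",
    )

    def mha_configs():
        # Hypothetical generator; the real test enumerates many more values.
        for b, s, kv in [(2, 1, 128), (2, 113, 211), (2, 2048, 2048)]:
            for n in (1, 3):
                for h in (16, 256):
                    yield Config(b, s, kv, 0, n, n, h)

    class TestMHA(unittest.TestCase):
        @parameterized.expand([(c,) for c in mha_configs()])
        def test_mha(self, config):
            # The real test builds and runs an MHA model for this config;
            # a trivial assertion keeps the sketch self-contained.
            self.assertGreater(config.head_size, 0)

Each config then becomes its own pytest item, so a failing combination is identifiable by name, and other backends such as ROCm can reuse the same case list.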

@cloudhan requested review from tianleiwu and aciddelgado on June 4, 2024
@cloudhan (Contributor, Author) commented on Jun 4, 2024

When run with `pytest -v`, this produces output like the following:

onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_00_Config_batch_size_2_sequence_length_1_kv_sequence_length_128_past_sequence_length_0_past_sequence_length_0_num_heads_1_kv_num_heads_1_head_size_16_ PASSED                        [  0%]
onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_01_Config_batch_size_2_sequence_length_1_kv_sequence_length_128_past_sequence_length_0_past_sequence_length_0_num_heads_1_kv_num_heads_1_head_size_256_ PASSED                       [  0%]
onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_02_Config_batch_size_2_sequence_length_1_kv_sequence_length_128_past_sequence_length_0_past_sequence_length_0_num_heads_3_kv_num_heads_3_head_size_16_ PASSED                        [  0%]
onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_03_Config_batch_size_2_sequence_length_1_kv_sequence_length_128_past_sequence_length_0_past_sequence_length_0_num_heads_3_kv_num_heads_3_head_size_256_ PASSED                       [  0%]
onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_04_Config_batch_size_2_sequence_length_113_kv_sequence_length_211_past_sequence_length_0_past_sequence_length_0_num_heads_1_kv_num_heads_1_head_size_16_ PASSED                      [  0%]
onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_05_Config_batch_size_2_sequence_length_113_kv_sequence_length_211_past_sequence_length_0_past_sequence_length_0_num_heads_1_kv_num_heads_1_head_size_256_ PASSED                     [  0%]
onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_06_Config_batch_size_2_sequence_length_113_kv_sequence_length_211_past_sequence_length_0_past_sequence_length_0_num_heads_3_kv_num_heads_3_head_size_16_ PASSED                      [  0%]
onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_07_Config_batch_size_2_sequence_length_113_kv_sequence_length_211_past_sequence_length_0_past_sequence_length_0_num_heads_3_kv_num_heads_3_head_size_256_ PASSED                     [  0%]
onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_08_Config_batch_size_2_sequence_length_2048_kv_sequence_length_2048_past_sequence_length_0_past_sequence_length_0_num_heads_1_kv_num_heads_1_head_size_16_ PASSED                    [  0%]
onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_09_Config_batch_size_2_sequence_length_2048_kv_sequence_length_2048_past_sequence_length_0_past_sequence_length_0_num_heads_1_kv_num_heads_1_head_size_256_ PASSED                   [  0%]
onnxruntime/test/python/transformers/test_flash_attn_cuda.py::TestMHA::test_mha_10_Config_batch_size_2_sequence_length_2048_kv_sequence_length_2048_past_sequence_length_0_past_sequence_length_0_num_heads_3_kv_num_heads_3_head_size_16_ PASSED                    [  0%]
... and more

@cloudhan force-pushed the guangyunhan/refactor-flash-attn-test branch from 490542e to 94a0da7 on June 4, 2024
@cloudhan force-pushed the guangyunhan/refactor-flash-attn-test branch from 94a0da7 to d216567 on June 5, 2024
@tianleiwu (Contributor) commented on Jun 7, 2024

To reduce the number of combinations in the build pipeline:

Instead of

    for b in batches:
        for s in seqs:
            for n in num_heads:
                for h in h_sizes:
                    ...  # run one test configuration

we can use something like:

    for b, s, n, h in configs:
        ...  # run one test configuration

or:

    import random

    test_cases = 10
    for i in range(test_cases):
        b = random.choice(batches)
        s = seqs[i % len(seqs)]
        n = num_heads[i % len(num_heads)]
        h = h_sizes[i % len(h_sizes)]
        ...  # run one test configuration

For example:

    batches = [1, 5]
    seqs = [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]
    num_heads = [1, 6, 16]
    h_sizes = [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

There are 2 * 11 * 3 * 10 = 660 combinations.

In the pipeline, we can use around 10 of them and still get good coverage:

    [batch=2, seq=97,   n=1,  h=32]
    [batch=1, seq=128,  n=6,  h=40]
    [batch=5, seq=200,  n=2,  h=64]
    [batch=1, seq=256,  n=4,  h=80]
    [batch=2, seq=257,  n=8,  h=96]
    [batch=3, seq=384,  n=1,  h=128]
    [batch=4, seq=512,  n=2,  h=160]
    [batch=1, seq=1024, n=4,  h=192]
    [batch=2, seq=1025, n=8,  h=224]
    [batch=3, seq=2048, n=16, h=256]
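
The second alternative above can be sketched concretely with the example lists; the fixed seed is an assumption added here (not part of the original suggestion) so that the selected configurations stay reproducible across CI runs:

    import random

    batches = [1, 5]
    seqs = [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]
    num_heads = [1, 6, 16]
    h_sizes = [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

    random.seed(0)  # assumed seed, so CI picks the same configs every run
    test_cases = 10
    configs = [
        (
            random.choice(batches),
            seqs[i % len(seqs)],
            num_heads[i % len(num_heads)],
            h_sizes[i % len(h_sizes)],
        )
        for i in range(test_cases)
    ]
    # Each (b, s, n, h) tuple can then be fed to parameterized.expand.

Cycling each list by index spreads the ten cases across the full value ranges, which is the same coverage goal as the hand-picked list above.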

@tianleiwu merged commit 67c8bef into main on Jun 11, 2024
94 of 96 checks passed
@tianleiwu deleted the guangyunhan/refactor-flash-attn-test branch on June 11, 2024
@cloudhan mentioned this pull request on Jun 13, 2024