Skip to content

Commit

Permalink
adjust gqa flash attention test threshold for rocm
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Jul 16, 2024
1 parent 4c3c809 commit ec83548
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions onnxruntime/test/python/transformers/test_flash_attn_rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
rtol=0.002,
atol=0.002,
rtol=0.001,
atol=0.005,
)
parity_check_gqa_prompt_no_buff(
config,
Expand All @@ -45,8 +45,8 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
rtol=0.002,
atol=0.002,
rtol=0.001,
atol=0.005,
)

@parameterized.expand(gqa_past_flash_attention_test_cases())
Expand All @@ -67,8 +67,8 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
rtol=0.002,
atol=0.002,
rtol=0.001,
atol=0.005,
)
parity_check_gqa_past_no_buff(
config,
Expand All @@ -77,8 +77,8 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
rtol=0.002,
atol=0.002,
rtol=0.001,
atol=0.005,
)


Expand Down

0 comments on commit ec83548

Please sign in to comment.