diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index bc3f6e5a7..542523453 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -536,4 +536,4 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): if d <= 128: assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() - assert (dk - dk_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()