From c1d146cbd5becd9e33634b1310c2d27a49c7e862 Mon Sep 17 00:00:00 2001 From: milesvant <26556534+milesvant@users.noreply.github.com> Date: Tue, 15 Oct 2024 13:54:40 -0700 Subject: [PATCH] Fix copy-paste error in hopper tests (#1279) --- hopper/test_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index bc3f6e5a7..542523453 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -536,4 +536,4 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): if d <= 128: assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() - assert (dk - dk_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()