From 7409bae64c85790927a679cc86a01248ffb4adb6 Mon Sep 17 00:00:00 2001 From: kaixih Date: Mon, 21 Oct 2024 17:00:04 +0000 Subject: [PATCH] Adjusted atol/rtol for jax sdpa tests --- tests/nn_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/nn_test.py b/tests/nn_test.py index df719256a921..0856b259c190 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -99,8 +99,8 @@ def testDotProductAttention(self, dtype, group_num, use_vmap, impl): self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) - self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) - self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) + self.assertAllClose(dK_ref, dK_ans, rtol=.01, atol=.01) + self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01) @parameterized.product( mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'), @@ -164,10 +164,10 @@ def testDotProductAttentionMask(self, mask_mode): self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad)) self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) - self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) + self.assertAllClose(dQ_ref, dQ_ans, rtol=.02, atol=.02) self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) - self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) - self.assertAllClose(dbias_ref, dbias_ans, rtol=.03, atol=.03) + self.assertAllClose(dV_ref, dV_ans, rtol=.01, atol=.01) + self.assertAllClose(dbias_ref, dbias_ans, rtol=.02, atol=.02) @parameterized.product( batch_size=[1, 16], @@ -224,7 +224,7 @@ def bwd_ans(x, bias, mask): else: _, dbias_ref, _ = bwd_ref(x, bias, mask) _, dbias_ans, _ = bwd_ans(x, bias, mask) - self.assertAllClose(dbias_ans, dbias_ref, rtol=.03, atol=.03) + self.assertAllClose(dbias_ans, dbias_ref, rtol=.02, atol=.02) @jtu.skip_on_flag("jax_skip_slow_tests", True) def testSoftplusGrad(self):