From aef8f6ec3727568539ceb2713c513e0eeba75934 Mon Sep 17 00:00:00 2001 From: Henry Tsang Date: Thu, 17 Oct 2024 18:16:34 -0700 Subject: [PATCH] Break down test_group_fused_layernorm_sigmoid_mul to avoid timeout (#1031) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/1031 test_group_fused_layernorm_sigmoid_mul can timeout. Break it down into smaller pieces. Reviewed By: ColinPeppler Differential Revision: D64557485 fbshipit-source-id: 216d03ac8d4af2bbe2608c8d655b2bc17230c750 --- .../ops/test_layernorm_sigmoid_mul.py | 33 ++++++++++++++----- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/tests/unittest/ops/test_layernorm_sigmoid_mul.py b/tests/unittest/ops/test_layernorm_sigmoid_mul.py index 0d41ff55d..756734e31 100644 --- a/tests/unittest/ops/test_layernorm_sigmoid_mul.py +++ b/tests/unittest/ops/test_layernorm_sigmoid_mul.py @@ -912,14 +912,6 @@ def test_group_fused_layernorm_sigmoid_mul(self, dtype: str): dtype=dtype, ) - # Make sure we test the boundary between being able to fit the arguments in constant memory vs not. - for num_groups in range(38, 41): - self._test_group_fused_layernorm_sigmoid_mul( - [[1024, 256]] * num_groups, - use_size_op=True, - dtype=dtype, - ) - # < 1024 kernel self._test_group_fused_layernorm_sigmoid_mul( [[4, 16]], @@ -986,6 +978,23 @@ def test_group_fused_layernorm_sigmoid_mul(self, dtype: str): [[128, 1025], [128, 0], [128, 1023]], dtype=dtype, ) + + @parameterized.expand( + [ + param("float16"), + param("float32"), + param("bfloat16"), + ] + ) + def test_group_fused_layernorm_sigmoid_mul_long(self, dtype: str): + # Make sure we test the boundary between being able to fit the arguments in constant memory vs not. + for num_groups in range(38, 41): + self._test_group_fused_layernorm_sigmoid_mul( + [[1024, 256]] * num_groups, + use_size_op=True, + dtype=dtype, + ) + # Ditto boundary test for num_groups_divided_by_3 in range(12, 15): self._test_group_fused_layernorm_sigmoid_mul( @@ -993,6 +1002,14 @@ def test_group_fused_layernorm_sigmoid_mul(self, dtype: str): dtype=dtype, ) + @parameterized.expand( + [ + param("float16"), + param("float32"), + param("bfloat16"), + ] + ) + def test_group_fused_layernorm_sigmoid_mul_nd(self, dtype: str): # ND self._test_group_fused_layernorm_sigmoid_mul( [[2, 512, 256, 16], [2, 512, 128, 4]],