diff --git a/tests/unittest/ops/test_transpose_conv2d.py b/tests/unittest/ops/test_transpose_conv2d.py index 363418ece..90da294ec 100644 --- a/tests/unittest/ops/test_transpose_conv2d.py +++ b/tests/unittest/ops/test_transpose_conv2d.py @@ -32,8 +32,10 @@ def _test_transpose_conv2d( copy_op=False, test_name="transpose_conv2d", dtype="float16", + grouped=False, ): target = detect_target() + groups = 256 if grouped else 1 X = Tensor( shape=[IntImm(batch), 28, 28, 256], dtype=dtype, @@ -41,25 +43,27 @@ def _test_transpose_conv2d( is_input=True, ) W = Tensor( - shape=[256, 2, 2, 256], + shape=[256, 2, 2, 256 // groups ], dtype=dtype, name="input_1", is_input=True, ) - OP = ops.transposed_conv2d(stride=2, pad=0, dilate=1) + OP = ops.transposed_conv2d(stride=2, pad=0, dilate=1, group=groups) if copy_op: OP = ops.transposed_conv2d(**OP._get_op_attributes()) + Y = OP(X, W) Y._attrs["name"] = "output_0" Y._attrs["is_output"] = True module = compile_model(Y, target, "./tmp", test_name) X_pt = get_random_torch_tensor([batch, 256, 28, 28], dtype=dtype) - W_pt = get_random_torch_tensor([256, 256, 2, 2], dtype=dtype) - Y_pt = torch.nn.functional.conv_transpose2d(X_pt, W_pt, padding=0, stride=2) + W_pt = get_random_torch_tensor([256, 256 // groups, 2, 2], dtype=dtype) + Y_pt = torch.nn.functional.conv_transpose2d(X_pt, W_pt, padding=0, stride=2, groups=groups) x = X_pt.permute((0, 2, 3, 1)).contiguous() w = W_pt.permute((0, 2, 3, 1)).contiguous() + y = torch.empty_like(Y_pt).permute((0, 2, 3, 1)).contiguous() module.run_with_tensors({"input_0": x, "input_1": w}, [y]) y_transpose = y.permute((0, 3, 1, 2)) @@ -68,6 +72,13 @@ def _test_transpose_conv2d( else: self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2)) + def test_grouped(self): + self._test_transpose_conv2d( + test_name="transpose_conv2d_fp16_grouped", + dtype="float16", + grouped=True + ) + def test_fp16(self): self._test_transpose_conv2d( test_name="transpose_conv2d_fp16", @@ -92,6 +103,8 @@ def test_fp32_sm80(self): ) + + filter_test_cases_by_test_env(Conv2dTransposeTestCase) if __name__ == "__main__":