Skip to content

Commit

Permalink
Add unittest for transposed grouped convolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
jonpryai committed Aug 16, 2024
1 parent 4360ae3 commit 4742149
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions tests/unittest/ops/test_transpose_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,38 @@ def _test_transpose_conv2d(
copy_op=False,
test_name="transpose_conv2d",
dtype="float16",
grouped=False,
):
target = detect_target()
groups = 256 if grouped else 1
X = Tensor(
shape=[IntImm(batch), 28, 28, 256],
dtype=dtype,
name="input_0",
is_input=True,
)
W = Tensor(
shape=[256, 2, 2, 256],
shape=[256, 2, 2, 256 // groups ],
dtype=dtype,
name="input_1",
is_input=True,
)
OP = ops.transposed_conv2d(stride=2, pad=0, dilate=1)
OP = ops.transposed_conv2d(stride=2, pad=0, dilate=1, group=groups)
if copy_op:
OP = ops.transposed_conv2d(**OP._get_op_attributes())

Y = OP(X, W)
Y._attrs["name"] = "output_0"
Y._attrs["is_output"] = True
module = compile_model(Y, target, "./tmp", test_name)

X_pt = get_random_torch_tensor([batch, 256, 28, 28], dtype=dtype)
W_pt = get_random_torch_tensor([256, 256, 2, 2], dtype=dtype)
Y_pt = torch.nn.functional.conv_transpose2d(X_pt, W_pt, padding=0, stride=2)
W_pt = get_random_torch_tensor([256, 256 // groups, 2, 2], dtype=dtype)
Y_pt = torch.nn.functional.conv_transpose2d(X_pt, W_pt, padding=0, stride=2, groups=groups)

x = X_pt.permute((0, 2, 3, 1)).contiguous()
w = W_pt.permute((0, 2, 3, 1)).contiguous()

y = torch.empty_like(Y_pt).permute((0, 2, 3, 1)).contiguous()
module.run_with_tensors({"input_0": x, "input_1": w}, [y])
y_transpose = y.permute((0, 3, 1, 2))
Expand All @@ -68,6 +72,13 @@ def _test_transpose_conv2d(
else:
self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2))

def test_grouped(self):
self._test_transpose_conv2d(
test_name="transpose_conv2d_fp16_grouped",
dtype="float16",
grouped=True
)

def test_fp16(self):
self._test_transpose_conv2d(
test_name="transpose_conv2d_fp16",
Expand All @@ -92,6 +103,8 @@ def test_fp32_sm80(self):
)




filter_test_cases_by_test_env(Conv2dTransposeTestCase)

if __name__ == "__main__":
Expand Down

0 comments on commit 4742149

Please sign in to comment.