-
Notifications
You must be signed in to change notification settings - Fork 351
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support chunk dynamo converter (#2401)
- Loading branch information
Showing
3 changed files
with
161 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import torch | ||
from parameterized import parameterized | ||
from torch.testing._internal.common_utils import run_tests | ||
|
||
from .harness import DispatchTestCase | ||
|
||
|
||
class TestChunkConverter(DispatchTestCase): | ||
@parameterized.expand( | ||
[ | ||
((1,), 3, 0), | ||
((3,), 3, 0), | ||
((4,), 3, 0), | ||
((6,), 3, 0), | ||
((3,), 1, -1), | ||
((3,), 3, -1), | ||
((3,), 4, -1), | ||
] | ||
) | ||
def test_chunk_1D(self, shape, chunks, dim): | ||
class TestChunk(torch.nn.Module): | ||
def forward(self, input): | ||
out = torch.ops.aten.chunk.default(input, chunks, dim) | ||
return out | ||
|
||
input = [torch.randn(shape)] | ||
self.run_test( | ||
TestChunk(), | ||
input, | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
((3, 4), 1, 0), | ||
((3, 4), 3, 0), | ||
((3, 4), 4, 0), | ||
((3, 4), 2, -2), | ||
((3, 4), 6, -2), | ||
((3, 4), 3, 1), | ||
((3, 4), 4, 1), | ||
((3, 4), 5, -1), | ||
] | ||
) | ||
def test_chunk_2D(self, shape, chunks, dim): | ||
class TestChunk(torch.nn.Module): | ||
def forward(self, input): | ||
out = torch.ops.aten.chunk.default(input, chunks, dim) | ||
return out | ||
|
||
input = [torch.randn(shape)] | ||
self.run_test( | ||
TestChunk(), | ||
input, | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
((3, 4, 2), 1, 0), | ||
((3, 4, 2), 3, -3), | ||
((3, 4, 2), 3, 1), | ||
((3, 4, 2), 4, 1), | ||
((3, 4, 2), 6, -2), | ||
((3, 4, 2), 1, 2), | ||
((3, 4, 2), 3, -1), | ||
((3, 4, 2), 4, -1), | ||
] | ||
) | ||
def test_chunk_3D(self, shape, chunks, dim): | ||
class TestChunk(torch.nn.Module): | ||
def forward(self, input): | ||
out = torch.ops.aten.chunk.default(input, chunks, dim) | ||
return out | ||
|
||
input = [torch.randn(shape)] | ||
self.run_test( | ||
TestChunk(), | ||
input, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |