[MLIR][TORCH] Add support for clone op with channels last memory format
Fixes llvm#1829

Signed-Off By: Vivek Khandelwal <[email protected]>
vivekkhandelwal1 authored and gpetters94 committed May 8, 2023
1 parent d6aaf48 commit f58a610
Showing 4 changed files with 34 additions and 4 deletions.
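For context, a minimal sketch of what this change enables (the module name and driver call below are illustrative; it assumes the torch_mlir Python package built from a revision that includes this commit and its torch_mlir.compile entry point): a clone with torch.channels_last can now be lowered instead of failing with the old "only default memory format is supported" error.

    import torch
    import torch_mlir  # assumes a torch-mlir build that includes this commit

    class CloneChannelsLast(torch.nn.Module):
        def forward(self, x):
            # Before this commit, lowering this op was rejected with
            # "unimplemented: only default memory format is supported".
            return torch.clone(x, memory_format=torch.channels_last)

    # Lower to linalg-on-tensors; "tosa" should also work after this change.
    compiled = torch_mlir.compile(CloneChannelsLast(), torch.rand(2, 3, 4, 5),
                                  output_type="linalg-on-tensors")
    print(compiled)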
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -270,6 +270,7 @@
"Convolution2DStaticModule_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
"ReturnThreeTensorFloat32_basic",
@@ -425,6 +426,7 @@
# and very few tests work yet.
TOSA_PASS_SET = {
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseBinaryModule_basic",
6 changes: 4 additions & 2 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -177,8 +177,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
(!matchPattern(clone.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)) ||
-        memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
-     clone.emitError("unimplemented: only default memory format is supported");
+        (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
+         memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
+     clone.emitError("unimplemented: only contiguous and channels last memory "
+                     "format is supported");
return nullptr;
}
return payloadArgs[0];
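Why channels last can share the contiguous code path above: the payload simply forwards payloadArgs[0], so the clone lowers to a value copy, and a memory format only describes the physical stride order, not the logical values. A plain-PyTorch sketch of that equivalence (no torch-mlir required):

    import torch

    x = torch.rand(2, 3, 4, 5)
    y = x.clone(memory_format=torch.channels_last)

    # The clone switches the physical layout to NHWC-ordered strides ...
    assert y.is_contiguous(memory_format=torch.channels_last)
    # ... but the logical values are unchanged, which is all the
    # elementwise lowering needs to preserve.
    assert torch.equal(x, y)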
6 changes: 4 additions & 2 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4077,9 +4077,11 @@ class ConvertAtenCloneOp : public OpConversionPattern<AtenOpT> {
if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
(!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)) ||
-         memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
+         (memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
+          memoryFormat != torch_upstream::MemoryFormat::ChannelsLast))) {
      return op.emitError(
-          "unimplemented: only default memory format is supported");
+          "unimplemented: only contiguous and channels last memory "
+          "format is supported");
}
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
24 changes: 24 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1963,6 +1963,30 @@ def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseCloneChannelsLastMemoryFormatModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.clone(x, memory_format=torch.channels_last)


@register_test_case(
module_factory=lambda: ElementwiseCloneChannelsLastMemoryFormatModule())
def ElementwiseCloneChannelsLastMemoryFormatModule_basic(
module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4, 5))
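Note: torch.channels_last is only defined for rank-4 (NCHW) tensors, hence the 4-d shape annotation and the tu.rand(2, 3, 4, 5) input above. To run just this case, the e2e harness can filter by test name, e.g. python -m e2e_testing.main -f ElementwiseCloneChannelsLastMemoryFormatModule -c tosa (flag spellings assumed; check the harness's --help).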


# ==============================================================================


class LiftFreshCopyModule(torch.nn.Module):

def __init__(self):
