From 7aec9a9570be0af928c416c802fbc1bc182074ca Mon Sep 17 00:00:00 2001
From: Henry Tu
Date: Tue, 12 Jul 2022 15:21:50 -0400
Subject: [PATCH] Prune xfail e2e LTC tests & fix bugs from functionalization
 pass (#1044)

- Pruned the number of xfailed e2e LTC tests from 305 to 134
- Reviewed every failure to ensure the error genuinely warrants an xfail
- Fixed a bug where non-tensor outputs of an LTC computation had `.to('cpu')`
  called on them, which caused a failure and inflated the xfail count
- Fixed a bug in the `HBC_basic` test where a constant tensor was created in
  its constructor without being declared as a buffer, which prevented its
  device from being updated when the parent `torch.nn.Module` was moved to
  the `lazy` device
  - Note that this test is still xfail'd due to some unsupported ops. Left a
    comment about potential issues that may arise if it is reenabled in the
    future
- Updated the autogenerated `GeneratedTorchOps.td` to reflect the latest set
  of supported ops
- Renamed `aten.zero.functional` to `aten.zero` to reflect upstream PyTorch
  changes
---
 e2e_testing/torchscript/xfail_sets.py         | 177 +-----------------
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  34 +---
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  10 +-
 .../Torch/Transforms/ReduceOpVariants.cpp     |   2 +-
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |   2 +-
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |   3 -
 .../jit_ir/build_tools/torch_ods_gen.py       |   3 +-
 .../histogram_binning_calibration.py          |   5 +-
 .../torchscript/configs/lazy_tensor_core.py   |  12 +-
 test/Dialect/Torch/decompose-complex-ops.mlir |   6 +-
 10 files changed, 34 insertions(+), 220 deletions(-)

diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 93e5dd60a98d..df5c2fd41d2f 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -177,60 +177,22 @@
 LTC_XFAIL_SET = {
     "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
     "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
-    "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
-    "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
     "AddIntModule_basic",
-    "AllBoolFalseModule_basic",
-    "AllBoolTrueModule_basic",
-    "AnyBoolFalseModule_basic",
-    "AnyBoolTrueModule_basic",
-    "ArangeDtypeFloatModule_basic",
-    "ArangeDtypeIntModule_basic",
-    "ArangeFalsePinMemoryModule_basic",
-    "ArangeFloatModule_basic",
-    "ArangeIntModule_basic",
-    "ArangeNegativeStartFloatModule_basic",
-    "ArangeNegativeStartIntModule_basic",
-    "ArangeStartFloatModule_basic",
-    "ArangeStartIntModule_basic",
-    "ArangeStartNegativeStepFloatModule_basic",
-    "ArangeStartNegativeStepIntModule_basic",
-    "ArangeStartStepFloatModule_basic",
-    "ArangeStartStepIntModule_basic",
-    "ArangeZeroElementOutputModule_basic",
-    "AvgPool2dCeilModeTrueModule_basic",
-    "AvgPool2dDivisorOverrideModule_basic",
-    "AvgPool2dFloatModule_basic",
-    "AvgPool2dIntModule_basic",
-    "AvgPool2dStaticModule_basic",
     "BernoulliFloatModule_basic",
     "BernoulliModule_basic",
-    "BernoulliOnesModule_basic",
     "BernoulliTensorModule_basic",
-    "BernoulliZerosModule_basic",
     "BincountMinlengthModule_basic",
     "BincountModule_basic",
     "BincountStaticSizeModule_basic",
-    "BoolFloatConstantModule_basic",
     "BoolFloatFalseModule_basic",
     "BoolFloatTrueModule_basic",
-    "BoolIntConstantModule_basic",
     "BoolIntFalseModule_basic",
     "BoolIntTrueModule_basic",
     "CeilFloatModule_basic",
     "DivFloatModule_basic",
     "DropoutTrainModule_basic",
-    "ElementwiseAtenLogicalOrOpBrodcastModule_basic",
-    "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
"ElementwiseAtenLogicalOrOpDiffArgs2Module_basic", - "ElementwiseAtenLogicalOrOpDiffArgs3Module_basic", - "ElementwiseAtenLogicalOrOpModule_basic", - "ElementwiseAtenLogicalOrOpNegativeModule_basic", - "ElementwiseAtenLogicalOrOpRandomFloatModule_basic", - "ElementwiseAtenLogicalOrOpRandomModule_basic", - "ElementwiseClampMaxModule_basic", - "ElementwiseClampMinModule_basic", - "ElementwiseClampModule_basic", + "ElementwiseAtenFloorDivideBroadcastModule_basic", + "ElementwiseAtenFloorDivideModule_basic", "ElementwiseWhereScalarModule_basic", "ElementwiseWhereScalarOtherModule_basic", "ElementwiseWhereScalarSelfModule_basic", @@ -240,11 +202,6 @@ "EmptyLikeModule_falsePinMemory", "EmptyLikeModule_float", "EmptyLikeModule_int", - "EmptyModule_contiguous", - "EmptyModule_defaultDtype", - "EmptyModule_falsePinMemory", - "EmptyModule_float", - "EmptyModule_int", "EqIntModule_basic", "Fill_TensorFloat64WithFloat32_basic", "Fill_TensorFloat64WithFloat64_basic", @@ -257,12 +214,6 @@ "FullLikeModuleInt2DStatic_basic", "FullLikeModuleInt2D_basic", "FullLikeModuleInt3D_basic", - "FullModuleDefaultDtype_basic", - "FullModuleFalsePinMemory_basic", - "FullModuleFloat2D_basic", - "FullModuleFloat3D_basic", - "FullModuleInt2D_basic", - "FullModuleInt3D_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", "GtFloatIntModule_basic", @@ -304,56 +255,13 @@ "IndexPutImpl2DFloatNonAccumulateModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", - "IndexSelectDynamicIndexSizeModule_basic", - "IndexSelectDynamicInputSizeModule_basic", - "IndexSelectDynamicModulebasic", - "IndexSelectSingleIdxModule_basic", - "IndexSelectTwoIdxModule_basic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", + "IndexTensorModule3dInput_basic", "IndexTensorModule_basic", - "MaskedFillScalarDefaultModule_basic", - "MaskedFillScalarFloatValueModule_basic", - "MaskedFillScalarIntValueModule_basic", "Matmul_dot", "Matmul_matvec", "Matmul_vecmat", - "MaxPool2dCeilModeTrueModule_basic", - "MaxPool2dModule_basic", - "MaxPool2dStaticModule_basic", - "MaxPool2dWith3dInputModule_basic", - "MaxPool2dWithIndicesAllNegativeValuesModule_basic", - "MaxPool2dWithIndicesAllOnesModule_basic", - "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", - "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", - "MaxPool2dWithIndicesBackwardStatic3DModule_basic", - "MaxPool2dWithIndicesBackwardStatic4DModule_basic", - "MaxPool2dWithIndicesCeilModeTrueModule_basic", - "MaxPool2dWithIndicesFullSizeKernelModule_basic", - "MaxPool2dWithIndicesModule_basic", - "MaxPool2dWithIndicesNonDefaultDilationModule_basic", - "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", - "MaxPool2dWithIndicesNonDefaultParamsModule_basic", - "MaxPool2dWithIndicesNonDefaultStrideModule_basic", - "MaxPool2dWithIndicesStaticModule_basic", - "MaxPool2dWithIndicesWith3dInputModule_basic", - "MeanDimAllReduceKeepdimModule_basic", - "MeanDimAllReduceModule_basic", - "MeanDimDtypeModule_basic", - "MeanDimKeepdimModule_basic", - "MeanDimModule_basic", - "MeanDimNegativeModule_basic", - "MeanDtypeModule_basic", - "MeanDynamicSizesModule_basic", - "MeanModule_basic", "MobilenetV3Module_basic", "MulIntModule_basic", - "NativeBatchNorm1DModule_basic", - "NativeBatchNorm2DModule_basic", - "NativeBatchNorm3DModule_basic", - "NativeBatchNormNoneWeightModule_basic", - "NativeLayerNormDynamicModule_basic", - "NativeLayerNormModule_basic", "NeFloatIntModule_basic", "NeIntModule_basic", 
"NewEmptyModuleDefaultDtype_basic", @@ -371,69 +279,23 @@ "NewOnesModuleFloat3D_basic", "NewOnesModuleInt2D_basic", "NewOnesModuleInt3D_basic", - "NewZerosModuleDefaultDtype_basic", - "NewZerosModuleFalsePinMemory_basic", - "NewZerosModuleFloat2D_basic", - "NewZerosModuleFloat3D_basic", - "NewZerosModuleInt2D_basic", - "NewZerosModuleInt3D_basic", - "NllLossModuleBackward1DMeanWeight_basic", - "NllLossModuleBackward1DMean_basic", - "NllLossModuleBackward1DSumWeight_basic", - "NllLossModuleBackward1DSum_basic", - "NllLossModuleBackward1DWeight_basic", - "NllLossModuleBackward1D_basic", - "NllLossModuleBackwardMeanWeight_basic", - "NllLossModuleBackwardMean_basic", - "NllLossModuleBackwardSumWeight_basic", - "NllLossModuleBackwardSum_basic", - "NllLossModuleBackwardWeight_basic", - "NllLossModuleBackward_basic", - "NllLossModuleBackward_ignore_index", - "NllLossModule_1D_basic", - "NllLossModule_basic", - "NllLossModule_ignore_index_out_of_bounds_basic", - "NllLossModule_mean_basic", - "NllLossModule_sum_basic", - "NumelModule_basic", - "NumelZeroRankModule_basic", "OnesLikeModule_defaultDtype", "OnesLikeModule_falsePinMemory", "OnesLikeModule_float", "OnesLikeModule_int", - "OnesModuleDefaultDtype_basic", - "OnesModuleFalsePinMemory_basic", - "OnesModuleFloat_basic", - "OnesModuleInt_basic", "QuantizedMLP_basic", "RandLikeDtypeModule_basic", "RandLikeModule_basic", - "ReduceMaxKeepDimReturnBoth_basic", - "ReduceMaxNegativeDim_basic", - "ReshapeAliasCollapseModule_basic", - "ReshapeAliasExpandModule_basic", - "ReturnThreeTensorFloat32_basic", - "ReturnTwoTensorF32I64_basic", "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", - "SelectIntModule_basic", "SliceEndSleStartModule_basic", - "SliceNegIdxModule_basic", - "SliceOutOfLowerBoundEndIndexModule_basic", - "SliceOutOfLowerBoundStartIndexModule_basic", "SliceOutOfUpperBoundIndexModule_basic", - "SliceSingleIdxModule_basic", - "SliceSizeTwoStepModule_basic", "SliceStartEqEndModule_basic", - "SliceWholeTensorModule_basic", - "SqrtIntConstantModule_basic", "SqrtIntModule_basic", "StdBiasedModule_basic", "StdUnbiasedModule_basic", "SubFloatModule_basic", "SubIntModule_basic", - "TModuleRank0_basic", - "TModuleRank1_basic", "TableBatchEmbeddingModule_basic", "TensorToBoolZeroRank_basic", "TensorToBool_basic", @@ -442,43 +304,10 @@ "TensorToIntZeroRank_basic", "TensorToInt_basic", "TensorsConcatModule_basic", - "TestMultipleTensorAndPrimitiveTypesReturn_basic", - "TestMultipleTensorReturn_basic", - "Threshold1dFloatModule_basic", - "Threshold1dIntI32Module_basic", - "Threshold1dIntModule_basic", - "Threshold2dFloatModule_basic", - "Threshold2dIntModule_basic", - "Threshold3dFloatModule_basic", - "Threshold3dIntModule_basic", - "ThresholdBackward1dFloatModule_basic", - "ThresholdBackward1dIntModule_basic", - "ThresholdBackward1dMixedModule_basic", - "ThresholdBackward2dFloatModule_basic", - "ThresholdBackward2dIntModule_basic", - "ThresholdBackward2dMixedModule_basic", - "ThresholdBackward3dFloatModule_basic", - "ThresholdBackward3dIntModule_basic", - "ThresholdBackward3dMixedModule_basic", - "TorchPrimLoopForLikeModule_basic", - "TorchPrimLoopWhileLikeModule_basic", "UniformModule_basic", "UniformStaticModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "VarBiasedModule_basic", "VarUnbiasedModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", - "ZeroFloat32Module_basic", - "ZeroInt32Module_basic", - "ZeroInt64Module_basic", - "ZerosLikeModule_defaultDtype", - "ZerosLikeModule_falsePinMemory", - 
"ZerosLikeModule_float", - "ZerosLikeModule_int", - "ZerosModuleDefaultDtype_basic", - "ZerosModuleFalsePinMemory_basic", - "ZerosModuleFloat2D_basic", - "ZerosModuleFloat3D_basic", - "ZerosModuleInt2D_basic", - "ZerosModuleInt3D_basic", } diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c1a21a1e2ccf..f8ca01af6a7a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -2346,12 +2346,12 @@ def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [ }]; } -def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [ +def Torch_AtenZeroOp : Torch_Op<"aten.zero", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::zero : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2360,16 +2360,17 @@ def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenZeroOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) { + void AtenZeroOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [ + IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`"; @@ -5564,6 +5565,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [ def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { let summary = "Generated op for `aten::_unsafe_view : (Tensor, int[]) -> (Tensor)`"; @@ -6659,30 +6661,6 @@ def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [ }]; } -def Torch_Aten_UnsafeViewCopyOp : Torch_Op<"aten._unsafe_view_copy", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::_unsafe_view_copy : (Tensor, int[]) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchListOfTorchIntType:$size - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult Aten_UnsafeViewCopyOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void Aten_UnsafeViewCopyOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 4f86e3794c64..fd96e2ea1c2a 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -222,11 +222,11 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern { } // namespace namespace { -class DecomposeAtenZeroFunctionalOp - : public OpRewritePattern { +class DecomposeAtenZeroOp + : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenZeroFunctionalOp op, + LogicalResult 
+  LogicalResult matchAndRewrite(AtenZeroOp op,
                                 PatternRewriter &rewriter) const override {
     Value zero = rewriter.create<ConstantIntOp>(op.getLoc(), rewriter.getI64IntegerAttr(0));
@@ -2200,8 +2200,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
-    patterns.add<DecomposeAtenZeroFunctionalOp>(context);
-    target.addIllegalOp<AtenZeroFunctionalOp>();
+    patterns.add<DecomposeAtenZeroOp>(context);
+    target.addIllegalOp<AtenZeroOp>();
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
index 0b6d6cad69b7..79ffea3dc10d 100644
--- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
+++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -184,7 +184,7 @@ class ReduceNonValueSemanticOps : public RewritePattern {
       newOp = rewriter.create(
           loc, op->getResultTypes(), op->getOperands());
     } else if (isa<AtenZero_Op>(op)) {
-      newOp = rewriter.create<AtenZeroFunctionalOp>(
+      newOp = rewriter.create<AtenZeroOp>(
          loc, op->getResultTypes(), op->getOperands());
     } else if (isa(op)) {
       newOp = rewriter.create(
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 8774c77c013e..16f1f12b4e4e 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -640,7 +640,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
           AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
           AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
-          AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
+          AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
           AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
           PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 0fd3d672176c..c32601065116 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6057,9 +6057,6 @@ module {
   func.func @"__torch_mlir_shape_fn.aten.zero"(%arg0: !torch.list<int>) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
   }
-  func.func @"__torch_mlir_shape_fn.aten.zero.functional"(%arg0: !torch.list<int>) -> !torch.list<int> {
-    return %arg0 : !torch.list<int>
-  }
   func.func @"__torch_mlir_shape_fn.aten.fill.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
     return %arg0 : !torch.list<int>
   }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index d17eb72e75be..954648df1702 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -288,7 +288,7 @@ def emit_with_mutating_variants(key, **kwargs):
             "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
             "aten::square : (Tensor) -> (Tensor)",
             "aten::unsqueeze : (Tensor, int) -> (Tensor)",
-            "aten::zero.functional : (Tensor) -> (Tensor)",
+            "aten::zero : (Tensor) -> (Tensor)",
     ]:
         emit_with_mutating_variants(key)
     # Elementwise tensor compute ops that don't have the standard mutating
@@ -492,7 +492,6 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
     emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
     emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)")
-    emit("aten::_unsafe_view_copy : (Tensor, int[]) -> (Tensor)")
 
     # Dict ops.
diff --git a/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py b/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py
index 8a176ce00ec1..82034d7719fc 100644
--- a/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py
+++ b/python/torch_mlir_e2e_test/test_suite/histogram_binning_calibration.py
@@ -38,7 +38,7 @@ def __init__(self):
             torch.empty([_num_interval], dtype=torch.float64).fill_(0.0),
         )
         self.register_buffer("_bin_ids", torch.arange(_num_interval))
-        self.positive_weight = torch.tensor([0.4])
+        self.register_buffer("positive_weight", torch.tensor([0.4]))
         self.bin_ctr_in_use_after = 0
         self.bin_ctr_weight_value = 0.9995
         self.oneminusbin_ctr_weight_value = 0.0005
@@ -54,6 +54,9 @@ def __init__(self):
     def forward(self, segment_value, segment_lengths, logit):
         origin_prediction = torch.sigmoid(
             logit + torch.log(self.positive_weight))
+        # TODO: If this test is ever removed from the LTC xfail set, we will probably hit
+        # device-related issues below, since these new tensors are created on the CPU (the
+        # current default). The fix would be to create them on the same device as the existing ones.
         dense_segment_value = torch.zeros(logit.numel(), dtype=torch.int32)
         validoffsets = torch.gt(
             segment_lengths[1:self._num_logits+1], segment_lengths[0:self._num_logits])
diff --git a/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py b/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py
index 9c5b90cda84a..95598a4476c5 100644
--- a/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py
+++ b/python/torch_mlir_e2e_test/torchscript/configs/lazy_tensor_core.py
@@ -5,9 +5,16 @@
 
 import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend
 import torch
+from torch.utils._pytree import tree_map
+
 from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
 
 
+def to_device(device):
+    """Returns a lambda that maps `torch.Tensor` objects to `device`, and ignores other types"""
+    return lambda e: e.to(device) if isinstance(e, torch.Tensor) else e
+
+
 class LazyTensorCoreTestConfig(TestConfig):
     """TestConfig that runs torch.nn.Module thru the Lazy Tensor Core frontend for Torch MLIR"""
 
@@ -23,12 +30,13 @@ def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
 
         for item in trace:
             # We need to move all the inputs to the lazy device before running in LTC.
-            lazy_inputs = [arg.to('lazy') for arg in item.inputs]
+            lazy_inputs = tree_map(to_device('lazy'), item.inputs)
             output = getattr(artifact, item.symbol)(*lazy_inputs)
+            cpu_outputs = tree_map(to_device('cpu'), output)
             result.append(
                 TraceItem(symbol=item.symbol,
                           inputs=item.inputs,
-                          output=output.to('cpu')))
+                          output=cpu_outputs))
 
         return result
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index c97a31f4f241..f343b18c435f 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -829,13 +829,13 @@ func.func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
 }
 
 // -----
-// CHECK-LABEL:   func.func @torch.aten.zero.functional(
+// CHECK-LABEL:   func.func @torch.aten.zero(
 // CHECK-SAME:                  %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK:           %[[ZERO:.*]] = torch.constant.int 0
 // CHECK:           %[[OUT:.*]] = torch.valsem.aten.fill.Scalar %[[INP]], %[[ZERO]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
 // CHECK:           return %[[OUT]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.zero.functional(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.zero.functional %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+func.func @torch.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.zero %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
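
A note on the `tree_map` change in `lazy_tensor_core.py`: it fixes the non-tensor-output bug because `tree_map` applies `to_device` to every `torch.Tensor` leaf of an arbitrarily nested output (tuples, lists, dicts) and passes other values through unchanged, whereas the old `output.to('cpu')` assumed the output was a single tensor. A minimal standalone sketch of the same pattern, using only stock PyTorch (the `outputs` value below is illustrative, not from the test suite):

    import torch
    from torch.utils._pytree import tree_map

    def to_device(device):
        # Move torch.Tensor leaves to `device`; pass non-tensor leaves through.
        return lambda e: e.to(device) if isinstance(e, torch.Tensor) else e

    # A mixed output: a tensor, a plain Python int, and a dict of tensors.
    outputs = (torch.ones(2), 3, {"loss": torch.zeros(1)})

    # tree_map recurses through the tuple and dict and applies the function to
    # every leaf; the int passes through untouched, so there is no
    # AttributeError from calling .to() on a non-tensor.
    cpu_outputs = tree_map(to_device("cpu"), outputs)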
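
Likewise, the `register_buffer` change in `histogram_binning_calibration.py` works because only parameters and registered buffers follow a module across `.to(device)`; a plain tensor attribute keeps the device it was created on. A small sketch of the difference, assuming stock PyTorch (`Scale` is an illustrative module, not part of the test suite, and the `meta` device stands in for the `lazy` device):

    import torch

    class Scale(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Plain attribute: does NOT follow the module across .to(device).
            self.plain = torch.tensor([0.4])
            # Registered buffer: moves with the module and is saved in state_dict.
            self.register_buffer("positive_weight", torch.tensor([0.4]))

        def forward(self, x):
            return x * self.positive_weight

    m = Scale().to("meta")
    print(m.plain.device)            # cpu  -- the stale-device bug
    print(m.positive_weight.device)  # meta -- the buffer followed the module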