Prune xfail e2e LTC tests & fix bugs from functionalization pass (#1044)
- Pruned the number of xfailed e2e LTC tests from 305 to 134
  - Reviewed every failure to ensure the error genuinely warrants an xfail
- Fixed a bug where `.to('cpu')` was called on non-tensor outputs of the LTC computation, which caused failures and inflated the xfail count (see the first sketch after this list)
- Fixed a bug in the `HBC_basic` test where a constant tensor was created in the constructor without being registered as a buffer, so its device was not updated when the parent `torch.nn.Module` was moved to the `lazy` device (see the second sketch after this list)
  - Note that this test is still xfail'd due to some unsupported ops. Left a comment about potential issues that may arise if it is re-enabled in the future
- Updated the autogenerated `GeneratedTorchOps.td` to reflect the latest set of supported ops
- Renamed `aten.zero.functional` to `aten.zero` to reflect upstream PyTorch changes
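
For reference, a minimal sketch of the non-tensor-output fix described above. The `outputs_to_cpu` helper and its shape are illustrative assumptions rather than the actual harness code; the real change lives in the LTC e2e test framework:

```python
import torch

def outputs_to_cpu(outputs):
    """Move LTC computation outputs back to CPU.

    Sketch only: the previous behaviour was equivalent to calling
    .to('cpu') unconditionally on every output, which raises for
    non-tensor values (Python ints, floats, bools) and inflated the
    xfail count.
    """
    moved = []
    for out in outputs:
        if isinstance(out, torch.Tensor):
            moved.append(out.to('cpu'))  # tensors are copied back to CPU
        else:
            moved.append(out)            # scalars and other values pass through unchanged
    return moved
```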
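And a sketch of the `HBC_basic` constructor fix. The module and attribute names here are made up for illustration, and `'meta'` stands in for the `lazy` device so the snippet runs without an LTC backend:

```python
import torch

class ConstantHolder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Bug pattern: a plain tensor attribute is invisible to Module.to(),
        # so it stays on its original device when the module is moved.
        self.lengths_plain = torch.ones(4, dtype=torch.int64)
        # Fix: register the constant as a buffer so device moves apply to it.
        self.register_buffer("lengths", torch.ones(4, dtype=torch.int64))

m = ConstantHolder().to('meta')
print(m.lengths_plain.device)  # cpu  -- the constant was left behind
print(m.lengths.device)        # meta -- buffers follow the module
```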
henrytwo authored Jul 12, 2022
1 parent 4e00281 commit 7aec9a9
Showing 10 changed files with 34 additions and 220 deletions.
177 changes: 3 additions & 174 deletions e2e_testing/torchscript/xfail_sets.py
@@ -177,60 +177,22 @@
LTC_XFAIL_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AddIntModule_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
"AnyBoolFalseModule_basic",
"AnyBoolTrueModule_basic",
"ArangeDtypeFloatModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
"ArangeFloatModule_basic",
"ArangeIntModule_basic",
"ArangeNegativeStartFloatModule_basic",
"ArangeNegativeStartIntModule_basic",
"ArangeStartFloatModule_basic",
"ArangeStartIntModule_basic",
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartNegativeStepIntModule_basic",
"ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"AvgPool2dCeilModeTrueModule_basic",
"AvgPool2dDivisorOverrideModule_basic",
"AvgPool2dFloatModule_basic",
"AvgPool2dIntModule_basic",
"AvgPool2dStaticModule_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliOnesModule_basic",
"BernoulliTensorModule_basic",
"BernoulliZerosModule_basic",
"BincountMinlengthModule_basic",
"BincountModule_basic",
"BincountStaticSizeModule_basic",
"BoolFloatConstantModule_basic",
"BoolFloatFalseModule_basic",
"BoolFloatTrueModule_basic",
"BoolIntConstantModule_basic",
"BoolIntFalseModule_basic",
"BoolIntTrueModule_basic",
"CeilFloatModule_basic",
"DivFloatModule_basic",
"DropoutTrainModule_basic",
"ElementwiseAtenLogicalOrOpBrodcastModule_basic",
"ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
"ElementwiseAtenLogicalOrOpDiffArgs2Module_basic",
"ElementwiseAtenLogicalOrOpDiffArgs3Module_basic",
"ElementwiseAtenLogicalOrOpModule_basic",
"ElementwiseAtenLogicalOrOpNegativeModule_basic",
"ElementwiseAtenLogicalOrOpRandomFloatModule_basic",
"ElementwiseAtenLogicalOrOpRandomModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseAtenFloorDivideBroadcastModule_basic",
"ElementwiseAtenFloorDivideModule_basic",
"ElementwiseWhereScalarModule_basic",
"ElementwiseWhereScalarOtherModule_basic",
"ElementwiseWhereScalarSelfModule_basic",
@@ -240,11 +202,6 @@
"EmptyLikeModule_falsePinMemory",
"EmptyLikeModule_float",
"EmptyLikeModule_int",
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
"EmptyModule_float",
"EmptyModule_int",
"EqIntModule_basic",
"Fill_TensorFloat64WithFloat32_basic",
"Fill_TensorFloat64WithFloat64_basic",
@@ -257,12 +214,6 @@
"FullLikeModuleInt2DStatic_basic",
"FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic",
"FullModuleDefaultDtype_basic",
"FullModuleFalsePinMemory_basic",
"FullModuleFloat2D_basic",
"FullModuleFloat3D_basic",
"FullModuleInt2D_basic",
"FullModuleInt3D_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GtFloatIntModule_basic",
@@ -304,56 +255,13 @@
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl3DFloatAccumulateModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectDynamicInputSizeModule_basic",
"IndexSelectDynamicModulebasic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"MaskedFillScalarDefaultModule_basic",
"MaskedFillScalarFloatValueModule_basic",
"MaskedFillScalarIntValueModule_basic",
"Matmul_dot",
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool2dCeilModeTrueModule_basic",
"MaxPool2dModule_basic",
"MaxPool2dStaticModule_basic",
"MaxPool2dWith3dInputModule_basic",
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesAllOnesModule_basic",
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
"MaxPool2dWithIndicesBackwardDynamic4DModule_basic",
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
"MaxPool2dWithIndicesBackwardStatic4DModule_basic",
"MaxPool2dWithIndicesCeilModeTrueModule_basic",
"MaxPool2dWithIndicesFullSizeKernelModule_basic",
"MaxPool2dWithIndicesModule_basic",
"MaxPool2dWithIndicesNonDefaultDilationModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
"MaxPool2dWithIndicesNonDefaultParamsModule_basic",
"MaxPool2dWithIndicesNonDefaultStrideModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
"MaxPool2dWithIndicesWith3dInputModule_basic",
"MeanDimAllReduceKeepdimModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimDtypeModule_basic",
"MeanDimKeepdimModule_basic",
"MeanDimModule_basic",
"MeanDimNegativeModule_basic",
"MeanDtypeModule_basic",
"MeanDynamicSizesModule_basic",
"MeanModule_basic",
"MobilenetV3Module_basic",
"MulIntModule_basic",
"NativeBatchNorm1DModule_basic",
"NativeBatchNorm2DModule_basic",
"NativeBatchNorm3DModule_basic",
"NativeBatchNormNoneWeightModule_basic",
"NativeLayerNormDynamicModule_basic",
"NativeLayerNormModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewEmptyModuleDefaultDtype_basic",
@@ -371,69 +279,23 @@
"NewOnesModuleFloat3D_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleFalsePinMemory_basic",
"NewZerosModuleFloat2D_basic",
"NewZerosModuleFloat3D_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NllLossModuleBackward1DMeanWeight_basic",
"NllLossModuleBackward1DMean_basic",
"NllLossModuleBackward1DSumWeight_basic",
"NllLossModuleBackward1DSum_basic",
"NllLossModuleBackward1DWeight_basic",
"NllLossModuleBackward1D_basic",
"NllLossModuleBackwardMeanWeight_basic",
"NllLossModuleBackwardMean_basic",
"NllLossModuleBackwardSumWeight_basic",
"NllLossModuleBackwardSum_basic",
"NllLossModuleBackwardWeight_basic",
"NllLossModuleBackward_basic",
"NllLossModuleBackward_ignore_index",
"NllLossModule_1D_basic",
"NllLossModule_basic",
"NllLossModule_ignore_index_out_of_bounds_basic",
"NllLossModule_mean_basic",
"NllLossModule_sum_basic",
"NumelModule_basic",
"NumelZeroRankModule_basic",
"OnesLikeModule_defaultDtype",
"OnesLikeModule_falsePinMemory",
"OnesLikeModule_float",
"OnesLikeModule_int",
"OnesModuleDefaultDtype_basic",
"OnesModuleFalsePinMemory_basic",
"OnesModuleFloat_basic",
"OnesModuleInt_basic",
"QuantizedMLP_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",
"ReduceMaxKeepDimReturnBoth_basic",
"ReduceMaxNegativeDim_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"SelectIntModule_basic",
"SliceEndSleStartModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceSingleIdxModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceStartEqEndModule_basic",
"SliceWholeTensorModule_basic",
"SqrtIntConstantModule_basic",
"SqrtIntModule_basic",
"StdBiasedModule_basic",
"StdUnbiasedModule_basic",
"SubFloatModule_basic",
"SubIntModule_basic",
"TModuleRank0_basic",
"TModuleRank1_basic",
"TableBatchEmbeddingModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToBool_basic",
@@ -442,43 +304,10 @@
"TensorToIntZeroRank_basic",
"TensorToInt_basic",
"TensorsConcatModule_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"TestMultipleTensorReturn_basic",
"Threshold1dFloatModule_basic",
"Threshold1dIntI32Module_basic",
"Threshold1dIntModule_basic",
"Threshold2dFloatModule_basic",
"Threshold2dIntModule_basic",
"Threshold3dFloatModule_basic",
"Threshold3dIntModule_basic",
"ThresholdBackward1dFloatModule_basic",
"ThresholdBackward1dIntModule_basic",
"ThresholdBackward1dMixedModule_basic",
"ThresholdBackward2dFloatModule_basic",
"ThresholdBackward2dIntModule_basic",
"ThresholdBackward2dMixedModule_basic",
"ThresholdBackward3dFloatModule_basic",
"ThresholdBackward3dIntModule_basic",
"ThresholdBackward3dMixedModule_basic",
"TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic",
"UniformModule_basic",
"UniformStaticModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"VarBiasedModule_basic",
"VarUnbiasedModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ZeroFloat32Module_basic",
"ZeroInt32Module_basic",
"ZeroInt64Module_basic",
"ZerosLikeModule_defaultDtype",
"ZerosLikeModule_falsePinMemory",
"ZerosLikeModule_float",
"ZerosLikeModule_int",
"ZerosModuleDefaultDtype_basic",
"ZerosModuleFalsePinMemory_basic",
"ZerosModuleFloat2D_basic",
"ZerosModuleFloat3D_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
}
34 changes: 6 additions & 28 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -2346,12 +2346,12 @@ def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
}];
}

def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
def Torch_AtenZeroOp : Torch_Op<"aten.zero", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`";
let summary = "Generated op for `aten::zero : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
@@ -2360,16 +2360,17 @@ def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
ParseResult AtenZeroOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) {
void AtenZeroOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
@@ -5564,6 +5565,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [

def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_unsafe_view : (Tensor, int[]) -> (Tensor)`";
@@ -6659,30 +6661,6 @@ def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [
}];
}

def Torch_Aten_UnsafeViewCopyOp : Torch_Op<"aten._unsafe_view_copy", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_unsafe_view_copy : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$size
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_UnsafeViewCopyOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten_UnsafeViewCopyOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
AllowsTypeRefinement,
HasValueSemantics,
10 changes: 5 additions & 5 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -222,11 +222,11 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern<AtenSelectIntOp> {
} // namespace

namespace {
class DecomposeAtenZeroFunctionalOp
: public OpRewritePattern<AtenZeroFunctionalOp> {
class DecomposeAtenZeroOp
: public OpRewritePattern<AtenZeroOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenZeroFunctionalOp op,
LogicalResult matchAndRewrite(AtenZeroOp op,
PatternRewriter &rewriter) const override {
Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
rewriter.getI64IntegerAttr(0));
@@ -2200,8 +2200,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
patterns.add<DecomposeValsemVariantAtenBernoulliTensorOp>(context);
target.addIllegalOp<ValsemVariantAtenBernoulliTensorOp>();
patterns.add<DecomposeAtenZeroFunctionalOp>(context);
target.addIllegalOp<AtenZeroFunctionalOp>();
patterns.add<DecomposeAtenZeroOp>(context);
target.addIllegalOp<AtenZeroOp>();
patterns.add<DecomposeAtenRandLikeOp>(context);
target.addIllegalOp<AtenRandLikeOp>();
patterns.add<DecomposeAtenHardsigmoidOp>(context);
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -184,7 +184,7 @@ class ReduceNonValueSemanticOps : public RewritePattern {
newOp = rewriter.create<ValsemVariantAtenBernoulliTensorOp>(
loc, op->getResultTypes(), op->getOperands());
} else if (isa<AtenZero_Op>(op)) {
newOp = rewriter.create<AtenZeroFunctionalOp>(
newOp = rewriter.create<AtenZeroOp>(
loc, op->getResultTypes(), op->getOperands());
} else if (isa<AtenFill_ScalarOp>(op)) {
newOp = rewriter.create<ValsemVariantAtenFillScalarOp>(
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -640,7 +640,7 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp,
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
3 changes: 0 additions & 3 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6057,9 +6057,6 @@
func.func @"__torch_mlir_shape_fn.aten.zero"(%arg0: !torch.list<int>) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.zero.functional"(%arg0: !torch.list<int>) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.fill.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
return %arg0 : !torch.list<int>
}