Skip to content

Commit

Permalink
Add dtype functions for ops that take dtype from 1st operand
Browse files Browse the repository at this point in the history
  • Loading branch information
li-plus committed Feb 19, 2023
1 parent 63945a2 commit 9c87303
Show file tree
Hide file tree
Showing 8 changed files with 1,850 additions and 283 deletions.
11 changes: 6 additions & 5 deletions build_tools/python_deploy/build_linux_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -264,20 +264,21 @@ function test_in_tree() {
echo ":::: Check that update_torch_ods.sh has been run"
_check_file_not_changed_by ./build_tools/update_torch_ods.sh include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

# TODO: pass `crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed` cases after dtype transition is done
echo ":::: Run Linalg e2e integration tests"
python -m e2e_testing.main --config=linalg -v
python -m e2e_testing.main --config=linalg -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed ElementwiseClampModule_basic IouOfModule_basic NormalizeModule_basic

echo ":::: Run MHLO e2e integration tests"
python -m e2e_testing.main --config=mhlo -v
python -m e2e_testing.main --config=mhlo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed ElementwiseClampModule_basic IouOfModule_basic NormalizeModule_basic

echo ":::: Run TOSA e2e integration tests"
python -m e2e_testing.main --config=tosa -v
python -m e2e_testing.main --config=tosa -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed ElementwiseClampModule_basic IouOfModule_basic NormalizeModule_basic

echo ":::: Run Lazy Tensor Core e2e integration tests"
python -m e2e_testing.main --config=lazy_tensor_core -v
python -m e2e_testing.main --config=lazy_tensor_core -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed ElementwiseClampModule_basic IouOfModule_basic NormalizeModule_basic

echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed DropoutTrainModule_basic RandnDtypeDeviceModule_basic ElementwiseClampModule_basic IouOfModule_basic NormalizeModule_basic
}

function setup_venv() {
Expand Down
3 changes: 1 addition & 2 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@

# Dtype function transition failures
"MobilenetV3Module_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
}

MHLO_PASS_SET = {
Expand Down Expand Up @@ -651,6 +649,7 @@
"HardsigmoidRandomModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"ResNet18StaticModule_basic",
}

LTC_XFAIL_SET = {
Expand Down
1,350 changes: 1,245 additions & 105 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Large diffs are not rendered by default.

33 changes: 1 addition & 32 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,36 +642,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

// Take dtype from first operand.
if (isa<CopyToValueTensorOp, CopyToNonValueTensorOp, AtenBatchNormOp,
AtenReluOp, AtenRelu6Op, AtenGeluOp, AtenCeilOp, AtenGeluBackwardOp,
AtenBitwiseNotOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopyOp, AtenCumsumOp,
AtenLayerNormOp, AtenClampOp, AtenClampMinOp, AtenClampMaxOp,
AtenNegOp, AtenFloorOp, Aten_SoftmaxBackwardDataOp, AtenDropoutOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
AtenAbsOp, AtenThresholdOp, AtenSquareOp, AtenUniformOp,
AtenBernoulliOp, AtenBernoulli_FloatOp, AtenBernoulliTensorOp,
ValsemVariantAtenBernoulliFloatOp, AtenBernoulliTensorOp,
AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp, AtenHardswishOp,
AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp, AtenMaxPool2dOp,
AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp, AtenFlattenUsingIntsOp,
AtenSqueezeOp, AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp,
Aten_UnsafeViewOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op,
AtenTransposeIntOp, AtenTOp, AtenPermuteOp, AtenIndexSelectOp,
AtenSelectIntOp, AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp,
AtenSliceScatterOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
AtenZero_Op, AtenIndexTensorOp, Aten_IndexPutImplOp, AtenIndexPutOp,
AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenPreluOp,
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
AtenUpsampleNearest2dBackwardOp, AtenLeakyReluBackwardOp>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

// Take dtype from second operand.
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
auto self = operands[1]->getValue();
Expand Down Expand Up @@ -1486,8 +1456,7 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
// the right thing for those ops.
//
static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) {
return op->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>() ||
isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
return op->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>();
}

// Some operations have extra verification logic regarding the relationship
Expand Down
Loading

0 comments on commit 9c87303

Please sign in to comment.