Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dtype functions for ops that take dtype from 1st operand #1895

Merged
merged 1 commit into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build_tools/python_deploy/build_linux_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ function test_in_tree() {
python -m e2e_testing.main --config=lazy_tensor_core -v

echo ":::: Run TorchDynamo e2e integration tests"
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic RandnLikeModule_basic Matmul_dot
python -m e2e_testing.main --config=torchdynamo -v --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed RandnDtypeDeviceModule_basic RandnLikeModule_basic
}

function setup_venv() {
Expand Down
5 changes: 2 additions & 3 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@

# Dtype function transition failures
"MobilenetV3Module_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
}

STABLEHLO_PASS_SET = {
Expand Down Expand Up @@ -706,7 +704,8 @@
"FullLikeModuleInt2DStatic_basic",
"FullModuleInt3D_basic",
"FullModuleFloat2D_basic",
"RepeatModule_basic"
"RepeatModule_basic",
"ResNet18StaticModule_basic",
}

LTC_XFAIL_SET = {
Expand Down
557 changes: 543 additions & 14 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Large diffs are not rendered by default.

36 changes: 1 addition & 35 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,39 +613,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

// Take dtype from first operand.
if (isa<CopyToValueTensorOp, CopyToNonValueTensorOp, AtenBatchNormOp,
AtenReluOp, AtenRelu6Op, AtenGeluOp, AtenCeilOp, AtenGeluBackwardOp,
AtenBitwiseNotOp, AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp,
AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopyOp, AtenCumsumOp,
AtenLayerNormOp, AtenClampOp, AtenClampMinOp, AtenClampMaxOp,
AtenNegOp, AtenFloorOp, Aten_SoftmaxBackwardDataOp, AtenDropoutOp,
AtenTanhBackwardOp, AtenHardtanhBackwardOp,
Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenAbsOp,
AtenThresholdOp, AtenSquareOp, AtenUniformOp, AtenBernoulliOp,
AtenBernoulli_FloatOp, AtenBernoulliTensorOp,
ValsemVariantAtenBernoulliFloatOp, AtenBernoulliTensorOp,
AtenBernoulliPOp, AtenFillScalarOp, AtenHardsigmoidOp, AtenCloneOp,
AtenHardswishOp, AtenSiluOp, AtenHardtanhOp, AtenMaskedSelectOp,
AtenMaxPool2dOp, AtenAvgPool2dOp, AtenAdaptiveAvgPool2dOp,
AtenFlattenUsingIntsOp, AtenSqueezeOp, AtenSqueezeDimOp,
AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp, AtenReshapeOp,
Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp, AtenTOp,
AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
AtenSelectScatterOp, AtenNarrowOp, AtenSliceTensorOp,
AtenScatterReduceTwoOp, AtenSliceScatterOp, AtenGatherOp,
AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp, AtenRepeatOp,
AtenConstantPadNdOp, AtenPadOp, AtenZero_Op, AtenIndexTensorOp,
Aten_IndexPutImplOp, AtenIndexPutOp, AtenCopyOp, AtenZeroOp,
AtenIndexPutHackedTwinOp, AtenPreluOp, AtenMaskedFillScalarOp,
AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp,
AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
AtenUpsampleNearest2dOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
AtenUpsampleNearest2dBackwardOp, AtenLeakyReluBackwardOp>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

// Dtype is always float32, except for bfloat16, float64 and nullptr after
// promotion and assuming possible-zero rank.
if (isa<AtenAtan2Op>(op)) {
Expand Down Expand Up @@ -1365,8 +1332,7 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) {
// the right thing for those ops.
//
static bool allowsTypeRefinementOrIsSafeToRefine(Operation *op) {
return op->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>() ||
isa<CopyToNonValueTensorOp, CopyToValueTensorOp>(op);
return op->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>();
}

// Some operations have extra verification logic regarding the relationship
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,14 @@ FailureOr<Value> Torch::adjustFunctionArg(
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}

// !torch.union<int, float> is the type used for `Scalar` inputs. At
// compile time, such inputs will usually be resolved to an `int` or a `float`
// so we need to derefine to match the library function signature.
// !torch.union<int, float> or !torch.union<int, float, none> is the type used
// for (optional) `Scalar` inputs. At compile time, such inputs will usually
// be resolved to an `int` or a `float` so we need to derefine to match the
// library function signature.
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType.isa<Torch::IntType, Torch::FloatType>();
return containedType
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
}))
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
Expand Down
Loading