Add dtype functions for two tensor promotion ops
This commit adds dtype functions for the ops in RefineTypes under the
category of "Promote the two dtypes". The only ops in that category not
handled here are the convolution ops, since they take an optional tensor
argument and the dtype pipeline currently does not handle that case
correctly. I will add a follow-up patch to fix this.
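
As background, a dtype function for a "promote the two dtypes" op simply mirrors PyTorch's type-promotion rules: given the two operand dtypes (and ranks), it returns the promoted result dtype. A minimal sketch of that behavior using stock PyTorch (illustrative only; `promoted_dtype` is a hypothetical name, not the library code added by this commit):

    import torch

    # Hypothetical illustration of the "promote the two dtypes" rule that the
    # new dtype functions encode. torch.result_type implements PyTorch's
    # promotion logic for a pair of operands.
    def promoted_dtype(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.dtype:
        return torch.result_type(lhs, rhs)

    print(promoted_dtype(torch.ones(2, dtype=torch.int32),
                         torch.ones(2, dtype=torch.float64)))  # torch.float64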

This commit also adds two helper functions that perform very thorough
testing of dtype functions. The helper function `_check_two_tensor_op`
can independently test invalid input dtypes and invalid output dtypes.
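
The helper itself lives in the torch-mlir Python test suite; the sketch below only illustrates the idea of checking invalid input dtypes and invalid output dtypes independently (all names and dtype lists are hypothetical, not the real helper's API):

    import torch

    # A dtype function maps the operand dtypes of a two-tensor op to a result
    # dtype, raising for combinations the op does not support (illustrative).
    def example_dtype_fn(lhs_dtype, rhs_dtype):
        if torch.bool in (lhs_dtype, rhs_dtype):
            raise AssertionError("unsupported input dtype")
        return torch.promote_types(lhs_dtype, rhs_dtype)

    def check_two_tensor_op(dtype_fn, invalid_input_dtypes, invalid_output_dtypes):
        dtypes = [torch.bool, torch.int32, torch.int64, torch.float32, torch.float64]
        for lhs in dtypes:
            for rhs in dtypes:
                if lhs in invalid_input_dtypes or rhs in invalid_input_dtypes:
                    # Invalid inputs must be rejected by the dtype function.
                    try:
                        dtype_fn(lhs, rhs)
                        raise RuntimeError(f"expected failure for {lhs}, {rhs}")
                    except AssertionError:
                        continue
                else:
                    # Valid inputs must never yield a forbidden output dtype.
                    assert dtype_fn(lhs, rhs) not in invalid_output_dtypes

    check_two_tensor_op(example_dtype_fn,
                        invalid_input_dtypes=[torch.bool],
                        invalid_output_dtypes=[torch.bool])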

Lastly, this commit also XFAILs "MobilenetV3Module_basic".
ramiro050 committed Jan 27, 2023
1 parent 83d4e89 commit 9e80fac
Showing 6 changed files with 1,280 additions and 243 deletions.
974 changes: 861 additions & 113 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Large diffs are not rendered by default.

21 changes: 3 additions & 18 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -693,10 +693,9 @@ void TypeAnalysis::visitOperation(Operation *op,
   }
 
   // Promote the two dtypes assuming non-zero rank.
-  if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
-          Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
-          AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp,
-          AtenMseLossOp>(op)) {
+  if (isa<AtenConv2dOp, AtenConvolutionOp,
+          Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
+          AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
@@ -705,20 +704,6 @@ void TypeAnalysis::visitOperation(Operation *op,
     return;
   }
 
-  // Promote the two dtypes assuming possibly-zero rank.
-  if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
-          AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
-          AtenMaximumOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
-          AtenBitwiseXorTensorOp, AtenThresholdBackwardOp>(op)) {
-    auto knowledge =
-        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
-    knowledge.dtype = getPromotedResultType(
-        op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
-        getRankIsNonZeroArray(op->getOperands()));
-    incorporateKnowledge(op->getResult(0), knowledge);
-    return;
-  }
-
   // Dtype is always float32, except for bfloat16, float64 and nullptr after
   // promotion and assuming possible-zero rank.
   if (isa<AtenAtan2Op>(op)) {
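
The retained block still uses getPromotedResultTypeAssumingNonZeroRank, while the removed block used getPromotedResultType with getRankIsNonZeroArray, because rank affects promotion: in PyTorch, a zero-dim tensor of the same category as a dimensioned tensor does not raise the result dtype. An illustrative check with stock PyTorch (not part of this patch):

    import torch

    vec32 = torch.ones(3, dtype=torch.float32)         # rank 1
    vec64 = torch.ones(3, dtype=torch.float64)         # rank 1
    scalar64 = torch.tensor(1.0, dtype=torch.float64)  # rank 0

    # Both operands dimensioned: ordinary promotion applies.
    print(torch.result_type(vec32, vec64))     # torch.float64
    # A zero-dim operand of the same category does not raise the result dtype,
    # which is why the promotion helpers need to know whether ranks are non-zero.
    print(torch.result_type(vec32, scalar64))  # torch.float32
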
2 changes: 2 additions & 0 deletions lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
@@ -190,6 +190,8 @@ class SimplifyDtypeCalculationsPass
     patterns.insert<DecomposePromoteDtypesOp>(context);
     patterns.insert<RefineNumToTensorScalarOpType>(context);
 
+    PrimIfOp::getCanonicalizationPatterns(patterns, context);
+
     // TODO: Debug visitation order to make this more efficient.
     // A single linear scan should suffice.
     GreedyRewriteConfig config;
