Add dtype functions for two tensor promotion ops
This commit adds dtype functions for ops in RefineTypes under the
category of "Promote the two dtypes". The only ops not added here are
the convolution ops, since they take an optional tensor argument and
the dtype pipeline currently does not handle that case correctly. I
will fix this in a follow-up patch.
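For context, a "promote the two dtypes" dtype function in the Python abstract
interpretation library roughly follows the pattern sketched below. This is an
illustrative sketch only: the function name, the (rank, dtype) argument
convention, and the use of torch.promote_types as the promotion rule are
assumptions, not the exact code added by this commit.

from typing import Tuple
import torch

def example_mm_dtype(self_rank_dtype: Tuple[int, torch.dtype],
                     mat2_rank_dtype: Tuple[int, torch.dtype]) -> torch.dtype:
    # Each operand is described by a (rank, dtype) pair. Since aten.mm
    # requires non-zero-rank operands, the result dtype is simply the
    # promotion of the two operand dtypes.
    _self_rank, self_dtype = self_rank_dtype
    _mat2_rank, mat2_dtype = mat2_rank_dtype
    return torch.promote_types(self_dtype, mat2_dtype)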

This commit also adds two helper functions that perform very thorough
testing of dtype functions. The helper function `_check_two_tensor_op`
can independently test invalid input dtypes and invalid output dtypes.
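A minimal sketch of what such a helper could look like is given below; the
name, signature, and use of torch.promote_types as the reference promotion
rule are assumptions for illustration rather than the helper's actual
interface in this commit.

from typing import Callable, Sequence
import torch

def check_two_tensor_op(dtype_fn: Callable,
                        invalid_input_dtypes: Sequence = (),
                        invalid_output_dtypes: Sequence = ()):
    # Exercise every pair of operand dtypes: pairs containing an invalid
    # input dtype must be rejected, pairs whose promoted result would be
    # an invalid output dtype must be rejected, and every other pair must
    # return the same dtype PyTorch itself would promote to.
    dtypes = [torch.bool, torch.int32, torch.int64,
              torch.float32, torch.float64]
    for lhs in dtypes:
        for rhs in dtypes:
            expected = torch.promote_types(lhs, rhs)
            should_fail = (lhs in invalid_input_dtypes or
                           rhs in invalid_input_dtypes or
                           expected in invalid_output_dtypes)
            if should_fail:
                try:
                    dtype_fn((2, lhs), (2, rhs))
                except Exception:
                    continue
                raise AssertionError(f"expected failure for ({lhs}, {rhs})")
            assert dtype_fn((2, lhs), (2, rhs)) == expected

For instance, check_two_tensor_op(example_mm_dtype) would verify that the
earlier sketch agrees with PyTorch's promotion for every dtype pair.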

Lastly, this commit XFAILs "MobilenetV3Module_basic".
ramiro050 committed Jan 27, 2023
1 parent 83d4e89 commit 517dc05
Showing 8 changed files with 1,285 additions and 267 deletions.
5 changes: 5 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -85,6 +85,11 @@
"ElementwisePreluModule_basic",
# error: op lowering missing. Issue: https://github.com/llvm/torch-mlir/issues/1792
"StdCorrectionKeepDimModule_basic",

# Dtype function transition failures
"MobilenetV3Module_basic",
"ResNet18Module_basic",
"ResNet18StaticModule_basic",
}

MHLO_PASS_SET = {
974 changes: 861 additions & 113 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Large diffs are not rendered by default.

21 changes: 3 additions & 18 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -693,10 +693,9 @@ void TypeAnalysis::visitOperation(Operation *op,
}

// Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp, AtenMvOp,
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp,
AtenMseLossOp>(op)) {
if (isa<AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
@@ -705,20 +704,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

// Promote the two dtypes assuming possibly-zero rank.
if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
AtenMaximumOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenThresholdBackwardOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultType(
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
getRankIsNonZeroArray(op->getOperands()));
incorporateKnowledge(op->getResult(0), knowledge);
return;
}

// Dtype is always float32, except for bfloat16, float64 and nullptr after
// promotion and assuming possible-zero rank.
if (isa<AtenAtan2Op>(op)) {
2 changes: 2 additions & 0 deletions lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
@@ -190,6 +190,8 @@ class SimplifyDtypeCalculationsPass
patterns.insert<DecomposePromoteDtypesOp>(context);
patterns.insert<RefineNumToTensorScalarOpType>(context);

PrimIfOp::getCanonicalizationPatterns(patterns, context);

// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
GreedyRewriteConfig config;

Large diffs are not rendered by default.

@@ -14,6 +14,16 @@

from .registry import Registry

def is_integer_dtype(dtype: int) -> bool:
return dtype in [torch.bool, torch.uint8, torch.int8,
torch.int16, torch.int32, torch.int64]

def is_complex_dtype(dtype: int) -> bool:
return dtype in [torch.complex64, torch.complex128]

def is_float_dtype(dtype: int) -> bool:
return dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64]

def get_dtype_of_scalar(scalar: Union[int, float]) -> int:
# This is hacky. `NumToTensor` is the only PyTorch op for scalars
# that when `jit.script`ed converts a float scalar to a tensor
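As a purely hypothetical illustration of how these predicates might be used
by a dtype function, the sketch below rejects complex operands before applying
the usual two-dtype promotion. The function name and the assert-based error
convention are assumptions, not code added by this commit; the sketch also
assumes `torch` and the is_*_dtype helpers above are in scope.

def example_checked_promotion(self_dtype, other_dtype):
    # Hypothetical: reject dtypes the lowering cannot handle, then promote.
    assert not is_complex_dtype(self_dtype), "unsupported complex operand"
    assert not is_complex_dtype(other_dtype), "unsupported complex operand"
    return torch.promote_types(self_dtype, other_dtype)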
1 change: 1 addition & 0 deletions python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -9,6 +9,7 @@
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"QuantizedMLP_basic",
"NormalizeModule_basic",
"MobilenetV3Module_basic",
}

def register_all_tests():
24 changes: 0 additions & 24 deletions test/Dialect/Torch/refine-types-ops.mlir
@@ -235,30 +235,6 @@ func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>,
return %ret : !torch.tensor
}

// -----
// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Matrix(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
return %0 : !torch.tensor
}

// -----
// CHECK-LABEL: func.func @torch.aten.Matmul.Broadcast.Vector(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<*,f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<*,f32>) -> !torch.tensor {
// CHECK: %[[MUL:.*]] = torch.aten.matmul %[[LHS]], %[[RHS]] : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[MUL]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.Matmul.Broadcast.Vector(%arg0: !torch.vtensor<*,f32>, %arg1: !torch.vtensor<*,f32>) -> !torch.tensor {
%0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<*,f32>, !torch.vtensor<*,f32> -> !torch.tensor
return %0 : !torch.tensor
}

// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
