Skip to content

Commit

Permalink
Change dtype functions interface to take ints tuple for each tensor
Browse files Browse the repository at this point in the history
The original design for the dtype functions outlined in
#1462 was unable to properly
handle ops that take optional tensors as an input when the optional
tensor has a value of None. By the time the op gets imported into
torch-mlir, if an optional value is None, all information about the
original type is lost from the op type signature, preventing
torch-mlir from knowing if a value of None was from an optional tensor
or not, which was crucial in the original design since each tensor
argument must be turned into two separate arguments for the dtype
function.

This commit changes the interface to dtype functions such that each
tensor turns into a tuple of two ints, the first representing the rank
of the tensor and the second the dtype of the tensor. Since now there
is a one-to-one correspondence between the operands of an op and the
operands of its dtype function, there is no ambiguity about which
operand of the op corresponds with which operand of the dtype
function.

To test the implementation, this commit defines dtype functions for
the convolution ops, all of which take one optional tensor as an
argument.
  • Loading branch information
ramiro050 committed Feb 13, 2023
1 parent 981ac88 commit ae86968
Show file tree
Hide file tree
Showing 10 changed files with 942 additions and 551 deletions.
2 changes: 1 addition & 1 deletion docs/adding_abstract_interpretation_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ We will use the example of adding support for the `torch.aten.tanh` op.
function signatures are:

- `def aten〇tanh〡shape(self: List[int]) -> List[int]:`
- `def aten〇tanh〡dtype(self_rank: int, self_dtype: int) -> int:`
- `def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:`

Note the use of `〇` as a separator since `.` or `::` aren't legal
in a Python identifier.
Expand Down
1 change: 0 additions & 1 deletion e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,6 @@
"ElementwiseFlattenBroadcastModule_basic",
"SquareModule_basic",
"MaxPool2dStaticModule_basic",
"ResNet18StaticModule_basic",
"ReduceAmaxKeepDim_basic",
"NativeLayerNormModule4D_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
Expand Down
1,019 changes: 651 additions & 368 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Large diffs are not rendered by default.

12 changes: 0 additions & 12 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,18 +692,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

// Promote the two dtypes assuming non-zero rank.
if (isa<AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp,
AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()});
incorporateKnowledge(op->getResult(0), knowledge);
return;
}

// Dtype is always float32, except for bfloat16, float64 and nullptr after
// promotion and assuming possible-zero rank.
if (isa<AtenAtan2Op>(op)) {
Expand Down
77 changes: 16 additions & 61 deletions lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,55 +19,25 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

// Returns true when `type` is a tensor type, or an optional/list type whose
// contained type (checked recursively) is a tensor type; returns false for
// every other type. Tuple types are rejected up front by an assertion.
static bool isTensorTypeOrWrappedTensorType(Type type) {
// Allowing tuples as arguments to dtype calculation functions can cause
// issues. For example, if an argument is a tuple of tensors and ints, there
// would be no way of differentiating the original ints from the ints created
// to represent the dtype and rank of the tensors. Therefore, to avoid this
// and keep things simple, the tuple type is not allowed. This works well in
// practice, since PyTorch op signatures don't seem to take tuples as inputs.
assert(!type.isa<Torch::TupleType>() &&
"dtype calculation functions are expected to not have tuples of "
"tensors as arguments");

// Base case: the type itself is a tensor type.
if (type.isa<Torch::BaseTensorType>())
return true;

// Otherwise peel off one optional/list wrapper and re-check the contained
// type; the recursion handles nested wrappers.
if (auto optionalType = type.dyn_cast<Torch::OptionalType>()) {
return isTensorTypeOrWrappedTensorType(optionalType.getContainedType());
} else if (auto listType = type.dyn_cast<Torch::ListType>()) {
return isTensorTypeOrWrappedTensorType(listType.getContainedType());
} else {
return false;
}
}

// Massage the op operands to match the dtype function signature.
// The dtype function generally takes the same operands as the op, with a few
// systematic modifications, such as replacing tensors with a rank and dtype
// argument.
// systematic modifications, such as replacing each tensor with a tuple of
// its rank and dtype.
static FailureOr<SmallVector<Value>>
dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
ValueRange originalOperands, func::FuncOp dtypeFunc) {
// Turns a tensor operand into an operand representing the rank of the tensor
auto rankArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
Type desiredType) -> Value {
if (desiredType.isa<Torch::IntType>() &&
operand.getType().isa<Torch::BaseTensorType>()) {
auto sizeListType =
Torch::ListType::get(Torch::IntType::get(b.getContext()));
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
return b.create<AtenLenTOp>(loc, desiredType, size);
}
return operand;
};

// Turns a tensor operand into an operand representing the dtype of the tensor
// Turn every tensor into a tuple of (tensor_rank, tensor_dtype)
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
Type desiredType) -> Value {
if (desiredType.isa<Torch::IntType>() &&
if (desiredType.isa<Torch::TupleType>() &&
operand.getType().isa<Torch::BaseTensorType>()) {
return b.create<PrimDtypeOp>(loc, desiredType, operand);
Type intType = Torch::IntType::get(b.getContext());
Type sizeListType = Torch::ListType::get(intType);
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
Value rank = b.create<AtenLenTOp>(loc, intType, size);
Value dtype = b.create<PrimDtypeOp>(loc, intType, operand);
return b.create<PrimTupleConstructOp>(loc, desiredType,
ArrayRef{rank, dtype});
}
return operand;
};
Expand All @@ -79,26 +49,11 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
"`dtypeFunc` should have at least one argument for each argument in "
"`originalOperands`");
Type desiredType = desiredTypes.front();
if (isTensorTypeOrWrappedTensorType(operand.getType())) {
assert(desiredTypes.size() >= 2 &&
"`dtypeFunc` should have two arguments for each tensor argument "
"in `originalOperands`");
FailureOr<Value> rankArg, dtypeArg;
if (failed(rankArg = adjustFunctionArg(b, loc, operand, desiredType,
rankArgAdjuster)))
return failure();
desiredTypes = desiredTypes.drop_front();
desiredType = desiredTypes.front();
if (failed(dtypeArg = adjustFunctionArg(b, loc, operand, desiredType,
dtypeArgAdjuster)))
return failure();
dtypeFuncArgs.append({*rankArg, *dtypeArg});
} else {
FailureOr<Value> otherArg;
if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType)))
return failure();
dtypeFuncArgs.push_back(*otherArg);
}
FailureOr<Value> otherArg;
if (failed(otherArg = adjustFunctionArg(b, loc, operand, desiredType,
dtypeArgAdjuster)))
return failure();
dtypeFuncArgs.push_back(*otherArg);
desiredTypes = desiredTypes.drop_front();
}

Expand Down
Loading

0 comments on commit ae86968

Please sign in to comment.