Skip to content

Commit

Permalink
Add dtype functions for "2/3 results take dtype from first operand" (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ramiro050 authored Mar 24, 2023
1 parent 2793c9f commit d3a49fd
Show file tree
Hide file tree
Showing 3 changed files with 397 additions and 50 deletions.
233 changes: 233 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9075,6 +9075,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %10 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%8, %9) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %10 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.convolution_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.bool, %arg8: !torch.list<int>, %arg9: !torch.int, %arg10: !torch.list<bool>) -> !torch.tuple<int, int, int> {\n"
" %false = torch.constant.bool false\n"
" %int5 = torch.constant.int 5\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.eq.int %2#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
" %8 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %8 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %7 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.convolution_backward_overrideable\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.list<bool>) -> !torch.tuple<int, int, int> {\n"
" %false = torch.constant.bool false\n"
" %int5 = torch.constant.int 5\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.eq.int %2#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
" %9 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %9 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %8 = torch.prim.TupleConstruct %7#1, %7#1, %7#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %8 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bincount\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.int) -> !torch.int {\n"
" %int7 = torch.constant.int 7\n"
" %int4 = torch.constant.int 4\n"
Expand Down Expand Up @@ -9588,6 +9665,162 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.__isnot__ %arg2, %none : !torch.optional<tuple<int, int>>, !torch.none -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" %6 = torch.prim.unchecked_cast %arg2 : !torch.optional<tuple<int, int>> -> !torch.tuple<int, int>\n"
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %8 = torch.aten.eq.int %7#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.eq.int %1#1, %int4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %5 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.native_layer_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float) -> !torch.tuple<int, int, int> {\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.__isnot__ %arg2, %none : !torch.optional<tuple<int, int>>, !torch.none -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" %5 = torch.prim.unchecked_cast %arg2 : !torch.optional<tuple<int, int>> -> !torch.tuple<int, int>\n"
" %6:2 = torch.prim.TupleUnpack %5 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %7 = torch.aten.eq.int %0#1, %6#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.__isnot__ %arg3, %none : !torch.optional<tuple<int, int>>, !torch.none -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" %5 = torch.prim.unchecked_cast %arg3 : !torch.optional<tuple<int, int>> -> !torch.tuple<int, int>\n"
" %6:2 = torch.prim.TupleUnpack %5 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %7 = torch.aten.eq.int %0#1, %6#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %4 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.native_batch_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple<int, int, int> {\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.__isnot__ %arg1, %none : !torch.optional<tuple<int, int>>, !torch.none -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" %7 = torch.prim.unchecked_cast %arg1 : !torch.optional<tuple<int, int>> -> !torch.tuple<int, int>\n"
" %8:2 = torch.prim.TupleUnpack %7 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %9 = torch.aten.eq.int %0#1, %8#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %9 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.__isnot__ %arg2, %none : !torch.optional<tuple<int, int>>, !torch.none -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" %7 = torch.prim.unchecked_cast %arg2 : !torch.optional<tuple<int, int>> -> !torch.tuple<int, int>\n"
" %8:2 = torch.prim.TupleUnpack %7 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %9 = torch.aten.eq.int %0#1, %8#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %9 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.__isnot__ %arg3, %none : !torch.optional<tuple<int, int>>, !torch.none -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" %7 = torch.prim.unchecked_cast %arg3 : !torch.optional<tuple<int, int>> -> !torch.tuple<int, int>\n"
" %8:2 = torch.prim.TupleUnpack %7 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %9 = torch.aten.eq.int %0#1, %8#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %9 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.__isnot__ %arg4, %none : !torch.optional<tuple<int, int>>, !torch.none -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" %7 = torch.prim.unchecked_cast %arg4 : !torch.optional<tuple<int, int>> -> !torch.tuple<int, int>\n"
" %8:2 = torch.prim.TupleUnpack %7 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %9 = torch.aten.eq.int %0#1, %8#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %9 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %6 : !torch.tuple<int, int, int>\n"
" }\n"
"}\n"
"";
// clang-format on
Expand Down
50 changes: 0 additions & 50 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,22 +536,6 @@ static Type getPromotedResultScalarType(ArrayRef<Type> scalarTypes) {
return *result;
}

// Computes the dtype that results from combining a tensor operand with a
// Torch scalar, following PyTorch's type-promotion rules. Returns the most
// generic type Type() if the tensor dtype is unknown or promotion fails.
static Type getPromotedResultDType(ValueKnowledge *tensor, Type scalarType) {
  if (!tensor->dtype)
    return Type();
  torch_upstream::ResultTypeState typeState = {};
  // No need to check if rank is zero for tensor because scalar uses
  // wrappedResult which is a lower priority than both dimResult and zeroResult.
  typeState = updateResultTypeState(tensor, /*rankIsNonZero=*/std::nullopt,
                                    typeState, /*skipRankCheck=*/true);
  typeState = updateResultTypeState(getDefaultDtypeForTorchScalar(scalarType),
                                    typeState);
  FailureOr<Type> promoted =
      getTypeForScalarType(scalarType.getContext(), result_type(typeState));
  if (failed(promoted))
    return Type();
  return *promoted;
}

static SmallVector<std::optional<bool>>
getRankIsNonZeroArray(ValueRange values) {
SmallVector<std::optional<bool>> rankIsNonZero;
Expand Down Expand Up @@ -684,40 +668,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

// 2 results take dtype from first operand.
if (isa<AtenNllLossForwardOp>(op)) {
  auto selfKnowledge = operands[0]->getValue();
  // Both results inherit the dtype of the `self` operand; everything else
  // about them stays at the pessimistic (unknown) tensor state.
  for (unsigned resultIdx = 0; resultIdx < 2; ++resultIdx) {
    auto knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    knowledge.dtype = selfKnowledge.dtype;
    incorporateKnowledge(op->getResult(resultIdx), knowledge);
  }
  return;
}

// 3 results take dtype from first operand.
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp,
        AtenConvolutionBackwardOp, AtenConvolutionBackwardOverrideableOp>(
        op)) {
  auto self = operands[0]->getValue();
  // All three results inherit the dtype of the `self` operand. Using a loop
  // instead of three hand-written knowledge objects fixes a copy-paste bug in
  // the original, which incorporated `result1Knowledge` into result 2 (benign
  // only because the per-result knowledge objects were identical) and removes
  // the hazard of reintroducing such a mismatch.
  for (unsigned resultIdx = 0; resultIdx < 3; ++resultIdx) {
    auto resultKnowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
    resultKnowledge.dtype = self.dtype;
    incorporateKnowledge(op->getResult(resultIdx), resultKnowledge);
  }
  return;
}

if (isa<AtenMaxPool2dWithIndicesOp>(op)) {
auto self = operands[0]->getValue();
auto result0Knowledge =
Expand Down
Loading

0 comments on commit d3a49fd

Please sign in to comment.