Skip to content

Commit

Permalink
Add dtype functions for ops that take dtype from 2nd operand
Browse files Browse the repository at this point in the history
  • Loading branch information
li-plus committed Feb 18, 2023
1 parent 63945a2 commit 670ec73
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 10 deletions.
82 changes: 82 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7687,6 +7687,88 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.tuple<int, int>) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: `self` cannot have float16 dtype\"\n"
" %int5 = torch.constant.int 5\n"
" %str_0 = torch.constant.str \"AssertionError: `self` cannot have integer dtype\"\n"
" %str_1 = torch.constant.str \"AssertionError: `self` cannot have complex dtype\"\n"
" %none = torch.constant.none\n"
" %str_2 = torch.constant.str \"AssertionError: `grad_output` and `self` must have the same dtype\"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %1#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.tuple<int, int>) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: `self` cannot have float16 dtype\"\n"
" %int5 = torch.constant.int 5\n"
" %str_0 = torch.constant.str \"AssertionError: `self` cannot have integer dtype\"\n"
" %str_1 = torch.constant.str \"AssertionError: `self` cannot have complex dtype\"\n"
" %none = torch.constant.none\n"
" %str_2 = torch.constant.str \"AssertionError: `grad_output` and `self` must have the same dtype\"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.aten.ne.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %7 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %1#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.all\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %int0 = torch.constant.int 0\n"
Expand Down
10 changes: 0 additions & 10 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,16 +672,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

// Take dtype from second operand.
if (isa<AtenNllLossBackwardOp, AtenMaxPool2dWithIndicesBackwardOp>(op)) {
auto self = operands[1]->getValue();
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = self.dtype;
incorporateKnowledge(op->getResult(0), knowledge);
return;
}

// Dtype is always si64.
if (isa<AtenBincountOp>(op)) {
auto knowledge =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,36 @@ def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
assert self_dtype != torch.float16
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_check_tensors_with_the_same_dtype(
None, [(3,), (3, 4)],
{torch.complex128, torch.complex64, torch.float16, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool},
TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)) +
[ErrorInvocation(TensorOfShape(3, dtype=torch.float32), TensorOfShape(3, 4, dtype=torch.float64), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32)),
ErrorInvocation(TensorOfShape(3, dtype=torch.float64), TensorOfShape(3, 4, dtype=torch.float32), TensorOfShape(3, dtype=torch.int64), None, 0, 10, TensorOfShape(1, dtype=torch.float32))])
def aten〇nll_loss_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], reduction: int, ignore_index: int, total_weight_rank_dtype: Tuple[int, int]) -> int:
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
self_rank, self_dtype = self_rank_dtype
assert grad_output_dtype == self_dtype, "`grad_output` and `self` must have the same dtype"
assert not is_complex_dtype(self_dtype), "`self` cannot have complex dtype"
assert not is_integer_dtype(self_dtype), "`self` cannot have integer dtype"
assert self_dtype != torch.float16, "`self` cannot have float16 dtype"
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(
None, [(2, 4, 7, 6), (2, 4, 6, 5)],
{torch.complex128, torch.complex64, torch.float16, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool},
[2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)) +
[ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float32), TensorOfShape(2, 4, 6, 5, dtype=torch.float64), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64)),
ErrorInvocation(TensorOfShape(2, 4, 7, 6, dtype=torch.float64), TensorOfShape(2, 4, 6, 5, dtype=torch.float32), [2, 2], [1, 1], [1, 1], [1, 1], False, TensorOfShape(2, 4, 7, 6, dtype=torch.int64))])
def aten〇max_pool2d_with_indices_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices_rank_dtype: Tuple[int, int]) -> int:
grad_output_rank, grad_output_dtype = grad_output_rank_dtype
self_rank, self_dtype = self_rank_dtype
assert grad_output_dtype == self_dtype, "`grad_output` and `self` must have the same dtype"
assert not is_complex_dtype(self_dtype), "`self` cannot have complex dtype"
assert not is_integer_dtype(self_dtype), "`self` cannot have integer dtype"
assert self_dtype != torch.float16, "`self` cannot have float16 dtype"
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇all〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down

0 comments on commit 670ec73

Please sign in to comment.