Add dtype functions for type conversion ops
ramiro050 committed Apr 24, 2023
1 parent 5c3f71d commit b20db8d
Showing 4 changed files with 68 additions and 58 deletions.
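
For context, the dtype rule these conversion ops share is simple: the result dtype is the explicitly requested dtype, or the dtype of the reference tensor for aten.to.other and aten.type_as. A minimal PyTorch-level sketch of that behavior (illustration only, not part of the change):

import torch

# Conversion ops take their result dtype either from an explicit `dtype`
# argument or from a reference tensor; the dtype functions added in this
# commit encode exactly this rule.
x = torch.ones(2, 3, dtype=torch.float32)

assert x.to(torch.int64).dtype == torch.int64                                    # aten.to.dtype
assert x.to("cpu", torch.int32).dtype == torch.int32                             # aten.to.device
assert x.to(torch.ones(1, dtype=torch.float16)).dtype == torch.float16           # aten.to.other
assert x.type_as(torch.ones(1, dtype=torch.complex64)).dtype == torch.complex64  # aten.type_as
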
29 changes: 29 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9832,6 +9832,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.to.dtype\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
" return %arg1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.nvprims.convert_element_type\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" return %arg1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.to.dtype_layout\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" torch.prim.If.yield %0#1 : !torch.int\n"
" } else {\n"
" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %3 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.to.device\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.Device, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool, %arg5: !torch.optional<int>) -> !torch.int {\n"
" return %arg2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.to.other\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.type_as\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
"}\n"
"";
// clang-format on
29 changes: 0 additions & 29 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -648,35 +648,6 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

if (auto toDtype = dyn_cast<AtenToDtypeOp>(op)) {
visitAtenToDtypeLikeOp<AtenToDtypeOp>(toDtype, operands);
return;
}

if (auto primsConvertElementType = dyn_cast<PrimsConvertElementTypeOp>(op)) {
visitAtenToDtypeLikeOp<PrimsConvertElementTypeOp>(primsConvertElementType,
operands);
return;
}

if (auto toDtypeLayout = dyn_cast<AtenToDtypeLayoutOp>(op)) {
visitAtenToDtypeLikeOp<AtenToDtypeLayoutOp>(toDtypeLayout, operands);
return;
}

if (auto toDtype = dyn_cast<AtenToDeviceOp>(op)) {
visitAtenToDtypeLikeOp<AtenToDeviceOp>(toDtype, operands);
return;
}

if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {
visitTypeConversionOp<AtenToOtherOp>(toOther, operands);
return;
} else if (auto typeAs = dyn_cast<AtenTypeAsOp>(op)) {
visitTypeConversionOp<AtenTypeAsOp>(typeAs, operands);
return;
}

if (auto cat = dyn_cast<AtenCatOp>(op)) {
visitAtenCatLikeOp<AtenCatOp>(cat, operands);
return;
@@ -2887,6 +2887,45 @@ def aten〇_to_copy〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[in
self_rank, self_dtype = self_rank_dtype
return self_dtype if dtype is None else dtype

@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64))
def aten〇to〇dtype〡dtype(self_rank_dtype: Tuple[int, int], dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int:
return dtype

@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64))
def nvprims〇convert_element_type〡dtype(a_rank_dtype: Tuple[int, int], dtype: int) -> int:
return dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.complex64))
def aten〇to〇dtype_layout〡dtype(self_rank_dtype: Tuple[int, int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype if dtype is None else dtype

@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.float16) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.int32) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, device="meta", dtype=torch.complex64))
def aten〇to〇device〡dtype(self_rank_dtype: Tuple[int, int], device: device, dtype: int, non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int:
return dtype

@check_dtype_function(_check_two_tensor_op())
def aten〇to〇other〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], non_blocking: bool = False, copy: bool = False, memory_format: Optional[int] = None) -> int:
other_rank, other_dtype = other_rank_dtype
return other_dtype

@check_dtype_function(_check_two_tensor_op())
def aten〇type_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = other_rank_dtype
return other_dtype

# ==============================================================================
# Main
# ==============================================================================
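
The @check_dtype_function invocations above exercise these functions against the real ops; roughly speaking, the framework runs each op on meta-device tensors and compares the observed result dtype with the dtype function's prediction. A minimal sketch of that idea in plain PyTorch (the helper below copies the body of aten〇to〇other〡dtype from the diff and is not torch-mlir's actual test harness):

import torch

# Tensors are abstracted as (rank, dtype) pairs; the body is copied from the
# dtype function added above.
def aten〇to〇other〡dtype(self_rank_dtype, other_rank_dtype,
                          non_blocking=False, copy=False, memory_format=None):
    other_rank, other_dtype = other_rank_dtype
    return other_dtype

# "meta" tensors carry shape and dtype but no data, so the real op can be run
# cheaply to observe the dtype it actually produces.
self_t = torch.empty(3, dtype=torch.int64, device="meta")
other_t = torch.empty(3, dtype=torch.float16, device="meta")

predicted = aten〇to〇other〡dtype((self_t.dim(), self_t.dtype),
                                  (other_t.dim(), other_t.dtype))
observed = self_t.to(other_t).dtype
assert predicted == observed == torch.float16
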
29 changes: 0 additions & 29 deletions test/Dialect/Torch/refine-types-ops.mlir
@@ -16,18 +16,6 @@ func.func @torch.aten.linear(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vte
return %1 : !torch.vtensor
}

// -----
// CHECK-LABEL: func.func @torch.aten.type_as(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
// CHECK: %[[RET:.*]] = torch.aten.type_as %[[INPUT]], %[[OTHER]] : !torch.tensor<[?],si64>, !torch.tensor<[?,2],f32> -> !torch.tensor<*,f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<*,f32> to !torch.tensor
// CHECK: return %[[CAST]] : !torch.tensor
func.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch.tensor<[?,2],f32>) -> !torch.tensor {
%ret = torch.aten.type_as %self, %other : !torch.tensor<[?], si64>, !torch.tensor<[?,2],f32> -> !torch.tensor
return %ret: !torch.tensor
}

// -----
// CHECK-LABEL: func.func @torch.aten.cat(
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
@@ -126,23 +114,6 @@ func.func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>,
return %ret : !torch.tensor
}

// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype
// CHECK-SAME: %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
// CHECK-SAME: -> !torch.tensor<*,si64>
// CHECK-NEXT: %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<*,si64> to !torch.tensor
// CHECK-NEXT: return %[[RES]] : !torch.tensor
func.func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
%none = torch.constant.none
%false = torch.constant.bool false
%int4 = torch.constant.int 4
%0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
return %0 : !torch.tensor
}

// -----
// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar(
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {