Replace AtenAddScalarOp with AtenRsubScalarOp
ramiro050 committed Dec 12, 2022
1 parent b9f941d · commit 634dd9a
Showing 3 changed files with 27 additions and 52 deletions.
lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp (17 additions & 42 deletions)
@@ -5641,10 +5641,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prims.convert_element_type\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.to.dtype_layout\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.bool, %arg7: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
@@ -5712,23 +5708,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = torch.prim.ListConstruct %arg1, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list<optional<int>>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union<float, int>) -> !torch.int {\n"
" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union<float, int> -> !torch.tensor\n"
" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.sub.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -5757,6 +5736,23 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rsub.Scalar\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.union<float, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0 = torch.prim.ListConstruct %arg0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %1 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.union<float, int>) -> !torch.int\n"
" %2 = torch.prim.ListConstruct %arg1, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %3 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%0, %2) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %3 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%arg0: !torch.list<optional<int>>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0 = torch.promote_dtypes %arg0, %arg1 : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0: !torch.union<float, int>) -> !torch.int {\n"
" %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.union<float, int> -> !torch.tensor\n"
" %1 = torch.prim.dtype %0 : !torch.tensor -> !torch.int\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.leaky_relu\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -5884,13 +5880,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %arg1 : !torch.list<int> to !torch.optional<list<int>>\n"
" %1 = torch.derefine %none : !torch.none to !torch.any\n"
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
@@ -6390,12 +6379,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randn\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg0 : !torch.float to !torch.union<float, int>\n"
" %1 = torch.derefine %arg1 : !torch.float to !torch.union<float, int>\n"
@@ -6504,18 +6487,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.ge.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.lt.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.le.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.unsqueeze\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
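The added `aten.rsub.Scalar` dtype function above pairs the tensor's rank and dtype with the scalar's implied dtype, then runs both lists through `torch.promote_dtypes`. As a rough eager-mode sanity check of the rule it encodes, `torch.result_type` (our stand-in here, not the helper the generated library calls) applies the same wrapped-number promotion:

```python
import torch

# A Python scalar participates in promotion as a "wrapped number": a float
# scalar does not upcast a lower-precision float tensor, but it does lift
# an integer tensor to the default float dtype.
print(torch.result_type(torch.ones(1, 1, 1, dtype=torch.float32), 0))    # torch.float32
print(torch.result_type(torch.ones(1, 1, 1, dtype=torch.int64), 0.0))    # torch.float32
print(torch.result_type(torch.ones(1, 1, 1, dtype=torch.float16), 0.0))  # torch.float16
```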
lib/Dialect/Torch/Transforms/RefineTypes.cpp (2 additions & 2 deletions)
@@ -779,8 +779,8 @@ void TypeAnalysis::visitOperation(Operation *op,
   }

   // Promote LHS with scalar RHS.
-  if (isa<AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp, AtenFmodScalarOp,
-          AtenFloorDivideScalarOp, AtenPowTensorScalarOp, AtenRsubScalarOp,
+  if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
+          AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
           AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
     auto lhs = operands[0]->getValue();
     Value scalar = op->getOperand(1);
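The `isa<...>` list gates ops whose result dtype RefineTypes computes by promoting the LHS tensor with the scalar RHS: `AtenRsubScalarOp` drops out because the library's new dtype function now covers it, and `AtenAddScalarOp` takes its place now that its dtype function is gone. A quick eager-mode illustration of that promotion rule (our own snippet, not part of the commit):

```python
import torch

# aten.add.Scalar: an int tensor plus a float scalar promotes to the
# default float dtype, while a float16 tensor keeps its own dtype.
print(torch.ops.aten.add(torch.ones(2, dtype=torch.int64), 2.5).dtype)    # torch.float32
print(torch.ops.aten.add(torch.ones(2, dtype=torch.float16), 2.5).dtype)  # torch.float16
```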
python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py (8 additions & 8 deletions)
@@ -227,14 +227,6 @@ def aten〇lt〇Scalar〡shape(self: List[int], other: float) -> List[int]:
 def aten〇add〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)

-@check_dtype_function([
-    Invocation(TensorOfShape(1, 1, 1, dtype=torch.float32), other=0),
-    Invocation(TensorOfShape(1, 1, 1, dtype=torch.int64), other=0.0),
-    Invocation(TensorOfShape(1, 1, 1, dtype=torch.float16), other=0.0)
-])
-def aten〇add〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
-    return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
-
 def aten〇sub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)

@@ -259,6 +251,14 @@ def aten〇pow〇Tensor_Tensor〡shape(self: List[int], exponent: List[int]) -> List[int]:
 def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)

+@check_dtype_function([
+    Invocation(TensorOfShape(1, 1, 1, dtype=torch.float32), other=0),
+    Invocation(TensorOfShape(1, 1, 1, dtype=torch.int64), other=0.0),
+    Invocation(TensorOfShape(1, 1, 1, dtype=torch.float16), other=0.0)
+])
+def aten〇rsub〇Scalar〡dtype(self_rank: int, self_dtype: int, other: Union[int, float], alpha: Union[int, float] = 1) -> int:
+    return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
+
 def aten〇leaky_relu〡shape(self: List[int], negative_slope: float = 0.01) -> List[int]:
     return upstream_shape_functions.unary(self)

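`check_dtype_function` validates the dtype function against eager PyTorch for each `Invocation` above. A minimal sketch of the same three expectations, assuming `torch.ops.aten.rsub` (which computes `other - alpha * self`):

```python
import torch

# The result dtype of rsub must match what the dtype function predicts
# for each (tensor dtype, scalar type) pair tested above.
assert torch.ops.aten.rsub(torch.ones(1, 1, 1, dtype=torch.float32), 0).dtype == torch.float32
assert torch.ops.aten.rsub(torch.ones(1, 1, 1, dtype=torch.int64), 0.0).dtype == torch.float32
assert torch.ops.aten.rsub(torch.ones(1, 1, 1, dtype=torch.float16), 0.0).dtype == torch.float16
```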