diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 9fb1a9dcfec7..0417048ffdda 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -882,102 +882,6 @@ def Torch_AtenBitwiseNot_Op : Torch_Op<"aten.bitwise_not_", [ }]; } -def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchScalarType:$alpha - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenSubTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenSubTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - -def Torch_AtenSub_TensorOp : Torch_Op<"aten.sub_.Tensor", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::sub_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchScalarType:$alpha - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenSub_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenSub_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - -def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::mul.Tensor : (Tensor, 
Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMulTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenMulTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenMul_TensorOp : Torch_Op<"aten.mul_.Tensor", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::mul_.Tensor : (Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMul_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenMul_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - def Torch_AtenDivTensorOp : Torch_Op<"aten.div.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -1072,55 +976,6 @@ def Torch_AtenLogicalOr_Op : Torch_Op<"aten.logical_or_", [ }]; } -def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::div.Tensor_mode : (Tensor, Tensor, str?) 
-> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchOptionalStringType:$rounding_mode - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenDivTensorModeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenDivTensorModeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - -def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::div_.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchOptionalStringType:$rounding_mode - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenDiv_TensorModeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenDiv_TensorModeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -1358,110 +1213,106 @@ def Torch_AtenNe_TensorOp : Torch_Op<"aten.ne_.Tensor", [ }]; } -def Torch_AtenAddScalarOp : Torch_Op<"aten.add.Scalar", [ +def Torch_AtenDivScalarOp : Torch_Op<"aten.div.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::div.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs 
AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAddScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenDivScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAddScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenDivScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenAdd_ScalarOp : Torch_Op<"aten.add_.Scalar", [ +def Torch_AtenDiv_ScalarOp : Torch_Op<"aten.div_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::add_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::div_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAdd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenDiv_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAdd_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenDiv_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenSubScalarOp : Torch_Op<"aten.sub.Scalar", [ +def Torch_AtenNeScalarOp : Torch_Op<"aten.ne.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::ne.Scalar : 
(Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSubScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenNeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSubScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenNeScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenSub_ScalarOp : Torch_Op<"aten.sub_.Scalar", [ +def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::sub_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::ne_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSub_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenNe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSub_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenNe_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ +def Torch_AtenEqScalarOp : Torch_Op<"aten.eq.Scalar", [ AllowsTypeRefinement, 
HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchScalarType:$other @@ -1471,20 +1322,20 @@ def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMulScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenEqScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMulScalarOp::print(OpAsmPrinter &printer) { + void AtenEqScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ +def Torch_AtenEq_ScalarOp : Torch_Op<"aten.eq_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::mul_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::eq_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchScalarType:$other @@ -1494,21 +1345,21 @@ def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMul_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenEq_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMul_ScalarOp::print(OpAsmPrinter &printer) { + void AtenEq_ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenDivScalarOp : Torch_Op<"aten.div.Scalar", [ +def Torch_AtenGtScalarOp : Torch_Op<"aten.gt.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::div.Scalar : (Tensor, 
Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchScalarType:$other @@ -1518,20 +1369,20 @@ def Torch_AtenDivScalarOp : Torch_Op<"aten.div.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenDivScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenGtScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenDivScalarOp::print(OpAsmPrinter &printer) { + void AtenGtScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenDiv_ScalarOp : Torch_Op<"aten.div_.Scalar", [ +def Torch_AtenGt_ScalarOp : Torch_Op<"aten.gt_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::div_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::gt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchScalarType:$other @@ -1541,21 +1392,21 @@ def Torch_AtenDiv_ScalarOp : Torch_Op<"aten.div_.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenDiv_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenGt_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenDiv_ScalarOp::print(OpAsmPrinter &printer) { + void AtenGt_ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenNeScalarOp : Torch_Op<"aten.ne.Scalar", [ +def Torch_AtenGeScalarOp : Torch_Op<"aten.ge.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::ge.Scalar : (Tensor, Scalar) -> 
(Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchScalarType:$other @@ -1565,20 +1416,20 @@ def Torch_AtenNeScalarOp : Torch_Op<"aten.ne.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenGeScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenNeScalarOp::print(OpAsmPrinter &printer) { + void AtenGeScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ +def Torch_AtenGe_ScalarOp : Torch_Op<"aten.ge_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::ne_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::ge_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchScalarType:$other @@ -1588,157 +1439,16 @@ def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenGe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenNe_ScalarOp::print(OpAsmPrinter &printer) { + void AtenGe_ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenEqScalarOp : Torch_Op<"aten.eq.Scalar", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult 
AtenEqScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenEqScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenEq_ScalarOp : Torch_Op<"aten.eq_.Scalar", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::eq_.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenEq_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenEq_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenGtScalarOp : Torch_Op<"aten.gt.Scalar", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenGtScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenGtScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenGt_ScalarOp : Torch_Op<"aten.gt_.Scalar", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::gt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ 
- ParseResult AtenGt_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenGt_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenGeScalarOp : Torch_Op<"aten.ge.Scalar", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenGeScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenGeScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenGe_ScalarOp : Torch_Op<"aten.ge_.Scalar", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::ge_.Scalar : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenGe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenGe_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenLtScalarOp : Torch_Op<"aten.lt.Scalar", [ +def Torch_AtenLtScalarOp : Torch_Op<"aten.lt.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly @@ -2622,6 +2332,104 @@ def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [ }]; } +def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op 
for `aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other, + AnyTorchOptionalStringType:$rounding_mode + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDivTensorModeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenDivTensorModeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::div_.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other, + AnyTorchOptionalStringType:$rounding_mode + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDiv_TensorModeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenDiv_TensorModeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMulTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMulTensorOp::print(OpAsmPrinter &printer) { + 
printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenMul_TensorOp : Torch_Op<"aten.mul_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::mul_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMul_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMul_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -2672,6 +2480,204 @@ def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [ }]; } +def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSubTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSubTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenSub_TensorOp : Torch_Op<"aten.sub_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::sub_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other, + 
AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSub_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSub_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenAddScalarOp : Torch_Op<"aten.add.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAddScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenAddScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenAdd_ScalarOp : Torch_Op<"aten.add_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::add_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenAdd_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenSubScalarOp : Torch_Op<"aten.sub.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary 
= "Generated op for `aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSubScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSubScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenSub_ScalarOp : Torch_Op<"aten.sub_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::sub_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSub_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSub_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMulScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMulScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let 
hasCanonicalizer = 1; +} + +def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::mul_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMul_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMul_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 4fe6fb9fac70..40206ac82409 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -129,6 +129,10 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) { static Value getScalarValue(Value input, Location loc, PatternRewriter &rewriter) { + auto inputType = input.getType(); + if (inputType.isa()) { + return input; + } Value scalar = nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { if (valueTensorLiteralOp && @@ -289,8 +293,9 @@ LogicalResult ClassTypeOp::verify() { // PrimLoopOp //===----------------------------------------------------------------------===// -OperandRange PrimLoopOp::getSuccessorEntryOperands(Optional index) { - assert(index.has_value() && index.value() == 0); +OperandRange +PrimLoopOp::getSuccessorEntryOperands(Optional index) { + assert(index.hasValue() && index.value() == 0); return iterArgsInit(); } @@ -509,8 +514,8 @@ OpFoldResult DerefineOp::fold(ArrayRef operands) { return nullptr; } -void DerefineOp::getCanonicalizationPatterns( - RewritePatternSet &patterns, MLIRContext *context) { +void 
DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) { bool madeChange = false; for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) { @@ -821,41 +826,163 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, //===----------------------------------------------------------------------===// OpFoldResult AtenLenStrOp::fold(ArrayRef<Attribute> operands) { - if(auto stringConstruct = s().getDefiningOp<ConstantStrOp>()) - return getI64IntegerAttr(getContext(), stringConstruct.valueAttr().getValue().size()); + if (auto stringConstruct = s().getDefiningOp<ConstantStrOp>()) + return getI64IntegerAttr(getContext(), + stringConstruct.valueAttr().getValue().size()); return nullptr; } +LogicalResult rewrite0DBinaryTensorOp(Operation *op, + PatternRewriter &rewriter) { + Location loc = op->getLoc(); + // This canonicalization pattern also includes aten div/mul/add/sub ops + // between tensor and scalar, like aten.add.Scalar op + if (op->getNumOperands() < 2) { + return failure(); + } + auto lhs = getScalarValue(op->getOperand(0), loc, rewriter); + auto rhs = getScalarValue(op->getOperand(1), loc, rewriter); + auto outType = op->getResult(0).getType(); + + if (!lhs || !rhs) { + return rewriter.notifyMatchFailure( + op, "only int scalar lhs or rhs is supported"); + } + if (isa<AtenSubTensorOp, AtenSubScalarOp, AtenAddTensorOp, AtenAddScalarOp>( + op)) { + Value alpha = getScalarValue(op->getOperand(2), loc, rewriter); + if (!alpha) { + return rewriter.notifyMatchFailure(op, + "only int scalar alpha is supported"); + } + rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha); + } + + if (isa<AtenDivTensorModeOp>(op)) { + // None rounding mode + if (op->getOperand(2).getType().isa<Torch::NoneType>()) { + Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs); + rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType, + quotient); + return success(); + } + std::string roundingMode; + if (!matchPattern(op->getOperand(2), m_TorchConstantStr(roundingMode))) { + return rewriter.notifyMatchFailure( + op, "only None, 'floor' or 'trunc' 
rounding mode is supported"); + } + if (roundingMode == "floor") { + Value quotient = rewriter.create(loc, lhs, rhs); + rewriter.replaceOpWithNewOp(op, outType, + quotient); + return success(); + } + // For "trunc" rounding mode, insted of canonicalizing it into + // aten.abs, aten.floor, aten.sign and aten.mul.int ops, which adds + // complexity but helps little in optimization (such as constant folding), + // we are trying to fold it. + if (roundingMode == "trunc") { + int64_t lhsInt; + int64_t rhsInt; + if (!matchPattern(lhs, m_TorchConstantInt(&lhsInt))) { + return failure(); + } + if (!matchPattern(rhs, m_TorchConstantInt(&rhsInt))) { + return failure(); + } + + int64_t result = (int64_t)std::trunc((double)lhsInt / rhsInt); + Value resultScalar = rewriter.create( + loc, rewriter.getI64IntegerAttr(result)); + rewriter.replaceOpWithNewOp(op, outType, + resultScalar); + return success(); + } + + return failure(); + } + + Value result; + // Other Add/Sub/Mul ops + if (isa(op)) { + result = rewriter.create(loc, lhs, rhs); + } else if (isa(op)) { + result = rewriter.create(loc, lhs, rhs); + } else if (isa(op)) { + result = rewriter.create(loc, lhs, rhs); + } + rewriter.replaceOpWithNewOp(op, outType, result); + return success(); +} + //===----------------------------------------------------------------------===// // AtenAddTensorOp //===----------------------------------------------------------------------===// - void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenAddTensorOp op, PatternRewriter &rewriter) { - // The lhs and rhs of the add.tensor op should be 0d tensors for the - // canonicalization to be carried out. - // `aten.add.tensor(self, other, alpha)` is canonicalized to - // `aten.add.int(self, aten.mul.int(other, alpha))`. 
- - Value lhs = getScalarValue(op.self(), op.getLoc(), rewriter); - if (!lhs) - return rewriter.notifyMatchFailure(op, "lhs scalar is empyty"); - if (!lhs.getType().isa()) - return rewriter.notifyMatchFailure(op, "lhs scalar is not IntType"); - - Value rhs = getScalarValue(op.other(), op.getLoc(), rewriter); - if (!rhs) - return rewriter.notifyMatchFailure(op, "rhs scalar is empyty"); - if (!rhs.getType().isa()) - return rewriter.notifyMatchFailure(op, "rhs scalar is not IntType"); - - Value mul = rewriter.create(op->getLoc(), rhs, op.alpha()); - Value add = rewriter.create(op->getLoc(), lhs, mul); - rewriter.replaceOpWithNewOp( - op, op.self().getType(), add); - return success(); + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenAddScalarOp +//===----------------------------------------------------------------------===// +void AtenAddScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenAddScalarOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenSubTensorOp +//===----------------------------------------------------------------------===// +void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenSubTensorOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenSubScalarOp +//===----------------------------------------------------------------------===// +void AtenSubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenSubScalarOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + 
+//===----------------------------------------------------------------------===// +// AtenMulTensorOp +//===----------------------------------------------------------------------===// +void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenMulTensorOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenMulScalarOp +//===----------------------------------------------------------------------===// +void AtenMulScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenMulScalarOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); + }); +} + +//===----------------------------------------------------------------------===// +// AtenDivTensorModeOp +//===----------------------------------------------------------------------===// +void AtenDivTensorModeOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenDivTensorModeOp op, PatternRewriter &rewriter) { + return rewrite0DBinaryTensorOp(op, rewriter); }); } @@ -1719,8 +1846,8 @@ OpFoldResult Aten__Contains__StrOp::fold(ArrayRef operands) { static bool isListConstructNotModified(Value torchList) { return llvm::all_of(torchList.getUsers(), [](Operation *op) { - return isa(op); - }); + return isa(op); + }); } OpFoldResult Aten__Contains__IntListOp::fold(ArrayRef operands) { @@ -2074,7 +2201,6 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() { if (walkResult.wasInterrupted()) return failure(); - return success(); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 973ee2a9fbec..9c84728f30d0 100644 --- 
a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -257,19 +257,13 @@ def emit_with_mutating_variants(key, **kwargs): "aten::floor : (Tensor) -> (Tensor)", "aten::ceil : (Tensor) -> (Tensor)", "aten::bitwise_not : (Tensor) -> (Tensor)", - "aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", - "aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::logical_or : (Tensor, Tensor) -> (Tensor)", - "aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)", - "aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", - "aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", - "aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::div.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", @@ -298,7 +292,14 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants(key) # Elementwise tensor compute ops that don't have the standard mutating # variants. - emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) 
-> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::maximum : (Tensor, Tensor) -> (Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index d5cfb0e36ca0..36aa26f1058a 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1414,3 +1414,210 @@ func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !t %1:2 = torch.prim.ListUnpack %0 : !torch.list -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> } + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %int6 = torch.constant.int 6 + %str = 
torch.constant.str "floor" + %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int6 = torch.constant.int 6 + %int2 = torch.constant.int 2 + %str = torch.constant.str "floor" + %0 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> 
+func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.add.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %2 = torch.aten.add.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT_6:.*]] = torch.constant.int -6 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int 
-> !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.sub.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT_6:.*]] = torch.constant.int -6 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.sub.Tensor %0, %1, %int3 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT_6:.*]] = torch.constant.int -6 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64> + %2 
= torch.aten.sub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT_6:.*]] = torch.constant.int -6 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %2 = torch.aten.sub.Scalar %0, %int2, %int3 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + %2 = torch.aten.mul.Scalar %0, %int3 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +func.func 
@torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.mul.Scalar %0, %int3 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { + %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + %1 = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> + %2 = torch.aten.mul.Tensor %0, %1 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int3 : !torch.int -> 
!torch.vtensor<[],si64> + %2 = torch.aten.mul.Tensor %0, %1 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> { + %int6 = torch.constant.int 6 + %int2 = torch.constant.int 2 + %str = torch.constant.str "trunc" + %0 = torch.prim.NumToTensor.Scalar %int2 : !torch.int -> !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +} + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> { +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> +// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> { + %int6 = torch.constant.int 6 + %str = torch.constant.str "trunc" + 
%0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + %1 = torch.prim.NumToTensor.Scalar %int6 : !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.div.Tensor_mode %1, %0, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64> + return %2 : !torch.vtensor<[],si64> +}