diff --git a/externals/llvm-project b/externals/llvm-project index a3f2751f782f..4acc3ffbb0af 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit a3f2751f782f3cdc6ba4790488ec20163a40ac37 +Subproject commit 4acc3ffbb0af5631bc7916aeff3570f448899647 diff --git a/externals/mlir-hlo b/externals/mlir-hlo index 97c7e4b4506c..16886a108eff 160000 --- a/externals/mlir-hlo +++ b/externals/mlir-hlo @@ -1 +1 @@ -Subproject commit 97c7e4b4506c3a2441c923e592833f45da439009 +Subproject commit 16886a108eff5197f816ca0f1950cc5ff1b078d9 diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index 7783b26abf08..64b70e097c39 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -10,6 +10,7 @@ #ifndef TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H #define TORCHMLIR_DIALECT_TORCH_IR_TORCHOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index bf2f20d8202a..31e8292452a9 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2346,7 +2346,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op.getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(toReduceShape), inputType.getElementType()), - sumDiv, rewriter.getI64IntegerAttr(i)); + sumDiv, rewriter.getI32IntegerAttr(i)); } return rewriter.create( @@ -3214,7 +3214,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( prunedShape.push_back(en.value()); } - auto dimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), dim); + auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim); auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape); Value reduceMax = rewriter.create( @@ -4787,7 +4787,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType); auto result = tosa::CreateOpAndInfer( - rewriter, loc, outType, builtinTensors, rewriter.getI64IntegerAttr(dim)); + rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim)); rewriter.replaceOp(op, result.getResult()); return success(); } diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 0abe8ac02b53..287a8943594a 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -382,7 +382,7 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixReducesumShape, indicesType.getElementType()), - flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1)); + flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1)); // And reshape to [N, W] // %7 = "tosa.reshape"(%6) {new_shape = [1, 8]} : (tensor<8x1xi32>) -> @@ -648,7 +648,7 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixReducesumShape, indicesType.getElementType()), - flattenedIndicesMulOp.getResult(), rewriter.getI64IntegerAttr(1)); + flattenedIndicesMulOp.getResult(), rewriter.getI32IntegerAttr(1)); // And reshape to [N, W] // [[1],[2],[3]] -> [[1,2,3]] @@ -717,7 +717,7 @@ std::optional convertReduceOpCommon( int64_t axis_val = axes_elems.getValues()[i].getInt(); if (axis_val < 0) axis_val += input_rank; - auto axis_attr = rewriter.getI64IntegerAttr(axis_val); + auto axis_attr = rewriter.getI32IntegerAttr(axis_val); shape_vec[axis_val] = 1; RankedTensorType reduce_type = RankedTensorType::get( diff --git a/lib/Dialect/Torch/IR/CMakeLists.txt b/lib/Dialect/Torch/IR/CMakeLists.txt index cf54afe06c2e..00210e4fd379 100644 --- a/lib/Dialect/Torch/IR/CMakeLists.txt +++ b/lib/Dialect/Torch/IR/CMakeLists.txt @@ -16,6 +16,9 @@ add_mlir_library(TorchMLIRTorchDialect Core LINK_LIBS PUBLIC + MLIRBytecodeOpInterface + MLIRBytecodeReader + MLIRBytecodeWriter MLIRFuncDialect MLIRIR MLIRSupport diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index eca213f8dc86..aed453a62da4 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -301,21 +301,20 @@ LogicalResult ClassTypeOp::verify() { // PrimLoopOp //===----------------------------------------------------------------------===// -OperandRange -PrimLoopOp::getEntrySuccessorOperands(std::optional index) { - assert(index.has_value() && index.value() == 0); +OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(point == getRegion()); return getIterArgsInit(); } void PrimLoopOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { - - if (!index.has_value()) { - regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1)); + RegionBranchPoint point, SmallVectorImpl ®ions) { + Region ®ion = getRegion(); + if (!point.getRegionOrNull()) { + regions.emplace_back(®ion, region.getArguments().slice(1)); return; } - assert(*index == 0); - regions.emplace_back(&getRegion(), getRegion().getArguments().slice(1)); + assert(point == region); + regions.emplace_back(®ion, region.getArguments().slice(1)); regions.emplace_back(getResults()); } @@ -328,8 +327,8 @@ bool PrimLoopOp::isForLike() { // PrimLoopConditionOp //===----------------------------------------------------------------------===// -MutableOperandRange PrimLoopConditionOp::getMutableSuccessorOperands( - std::optional index) { +MutableOperandRange +PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { // Pass all operands except the condition to the successor which is the // parent loop op. return getIterArgsMutable(); @@ -378,10 +377,10 @@ void PrimIfOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); } -void PrimIfOp::getSuccessorRegions(std::optional index, +void PrimIfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The `then` and the `else` region branch back to the parent operation. - if (index.has_value()) { + if (point.getRegionOrNull()) { regions.push_back(RegionSuccessor(getResults())); return; } @@ -1595,7 +1594,9 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - auto attr = attributes.get("value").dyn_cast_or_null(); + auto attr = properties.as() + ->getValue() + .dyn_cast_or_null(); if (!attr) return failure(); RankedTensorType tensorType = attr.getType().cast(); @@ -1635,7 +1636,9 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - auto attr = attributes.get("value").dyn_cast_or_null(); + auto attr = properties.as() + ->getValue() + .dyn_cast_or_null(); if (!attr) return failure(); RankedTensorType tensorType = attr.getType().cast(); @@ -2768,26 +2771,26 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) { template static void -getSuccessorRegionsForCalculateOp(CalculateOp op, std::optional index, +getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point, SmallVectorImpl ®ions) { - if (!index.has_value()) { + if (!point.getRegionOrNull()) { // First thing the op does is branch into the calculation. regions.emplace_back(&op.getCalculation()); return; } - if (*index == 0) { + if (point == op.getBody()) { // Body returns control to the outer op, passing through results. regions.emplace_back(op.getResults()); return; } - assert(*index == 1); + assert(point == op.getCalculation()); // Calculation branches to the body. regions.emplace_back(&op.getBody()); } void ShapeCalculateOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { - getSuccessorRegionsForCalculateOp(*this, index, regions); + RegionBranchPoint point, SmallVectorImpl ®ions) { + getSuccessorRegionsForCalculateOp(*this, point, regions); } //===----------------------------------------------------------------------===// @@ -2795,8 +2798,8 @@ void ShapeCalculateOp::getSuccessorRegions( //===----------------------------------------------------------------------===// void DtypeCalculateOp::getSuccessorRegions( - std::optional index, SmallVectorImpl ®ions) { - getSuccessorRegionsForCalculateOp(*this, index, regions); + RegionBranchPoint point, SmallVectorImpl ®ions) { + getSuccessorRegionsForCalculateOp(*this, point, regions); } //===----------------------------------------------------------------------===// @@ -2804,7 +2807,7 @@ void DtypeCalculateOp::getSuccessorRegions( //===----------------------------------------------------------------------===// MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands( - std::optional index) { + RegionBranchPoint point) { // The shape operands don't get forwarded to the body. // MutableOperandRange always has an owning operation, even if empty, so // create a 0-length range. @@ -2823,7 +2826,7 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() { //===----------------------------------------------------------------------===// MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands( - std::optional index) { + RegionBranchPoint point) { // The dtype operands don't get forwarded to the body. // MutableOperandRange always has an owning operation, even if empty, so // create a 0-length range. diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index df8c148902b9..49907a98a56d 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func.func @torch.aten.tanh$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.tanh"(%[[ARG_BUILTIN]]) : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.tanh %[[ARG_BUILTIN]] : (tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -16,7 +16,7 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK-LABEL: func.func @torch.aten.sigmoid$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.sigmoid"(%[[ARG_BUILTIN]]) : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.sigmoid %[[ARG_BUILTIN]] : (tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -29,7 +29,7 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch. // CHECK-LABEL: func.func @torch.aten.relu$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.clamp"(%[[ARG_BUILTIN]]) <{max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.clamp %[[ARG_BUILTIN]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -46,9 +46,9 @@ func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK: %[[VAL_1:.*]] = torch.constant.float 1.000000e-01 // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor}> : () -> tensor // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.greater_equal"(%[[VAL_0]], %[[VAL_3]]) : (tensor, tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_2]]) <{shift = 0 : i32}> : (tensor, tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_0]], %[[VAL_5]]) : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_0]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_0]], %[[VAL_2]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_4]], %[[VAL_0]], %[[VAL_5]] : (tensor, tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -64,7 +64,7 @@ func.func @torch.aten.leaky_relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK-LABEL: func.func @torch.aten.log$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.log"(%[[ARG_BUILTIN]]) : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.log %[[ARG_BUILTIN]] : (tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -77,7 +77,7 @@ func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.exp$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.exp"(%[[ARG_BUILTIN]]) : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.exp %[[ARG_BUILTIN]] : (tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -90,7 +90,7 @@ func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.neg$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.negate"(%[[ARG_BUILTIN]]) : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.negate %[[ARG_BUILTIN]] : (tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -103,7 +103,7 @@ func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.floor$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.floor"(%[[ARG_BUILTIN]]) : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.floor %[[ARG_BUILTIN]] : (tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.floor$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -116,7 +116,7 @@ func.func @torch.aten.floor$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt // CHECK-LABEL: func.func @torch.aten.bitwise_not$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.bitwise_not"(%[[ARG_BUILTIN]]) : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.bitwise_not %[[ARG_BUILTIN]] : (tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -129,7 +129,7 @@ func.func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // CHECK-LABEL: func.func @torch.aten.ceil$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = "tosa.ceil"(%[[VAL_1]]) : (tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.ceil %[[VAL_1]] : (tensor) -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -143,7 +143,7 @@ func.func @torch.aten.ceil$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK-LABEL: func.func @torch.aten.reciprocal$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.reciprocal %[[VAL_1]] : (tensor) -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -161,8 +161,8 @@ func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_6]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -181,8 +181,8 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_2]], %[[VAL_6]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -199,7 +199,7 @@ func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) <{shift = 0 : i32}> : (tensor, tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.mul %[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]] {shift = 0 : i32} : (tensor, tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { @@ -214,8 +214,8 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RCP:.*]] = "tosa.reciprocal"(%[[ARG1_BUILTIN]]) : (tensor) -> tensor -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[ARG0_BUILTIN]], %[[RCP]]) <{shift = 0 : i32}> : (tensor, tensor) -> tensor +// CHECK: %[[RCP:.*]] = tosa.reciprocal %[[ARG1_BUILTIN]] : (tensor) -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.mul %[[ARG0_BUILTIN]], %[[RCP]] {shift = 0 : i32} : (tensor, tensor) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { @@ -244,8 +244,8 @@ func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK: %[[ARG2_BUILTIN:.*]] = torch.constant.bool false // CHECK: %[[ARG3:.*]] = torch.constant.int 0 // CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list -// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) <{axis = 0 : i64}> : (tensor) -> tensor<1x?x?x?xf32> -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) <{new_shape = array}> : (tensor<1x?x?x?xf32>) -> tensor +// CHECK: %[[SUM:.*]] = tosa.reduce_sum %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xf32> +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[SUM]] {new_shape = array} : (tensor<1x?x?x?xf32>) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32> func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -263,11 +263,11 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[ARG1_BUILTIN:.*]] = torch.constant.none -// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) <{axis = 0 : i64}> : (tensor) -> tensor<1x?x?x?xf32> -// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_sum"(%[[REDUCE1]]) <{axis = 1 : i64}> : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32> -// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_sum"(%[[REDUCE2]]) <{axis = 2 : i64}> : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32> -// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_sum"(%[[REDUCE3]]) <{axis = 3 : i64}> : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32> -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) <{new_shape = array}> : (tensor<1x1x1x1xf32>) -> tensor<1xf32> +// CHECK: %[[REDUCE1:.*]] = tosa.reduce_sum %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xf32> +// CHECK: %[[REDUCE2:.*]] = tosa.reduce_sum %[[REDUCE1]] {axis = 1 : i32} : (tensor<1x?x?x?xf32>) -> tensor<1x1x?x?xf32> +// CHECK: %[[REDUCE3:.*]] = tosa.reduce_sum %[[REDUCE2]] {axis = 2 : i32} : (tensor<1x1x?x?xf32>) -> tensor<1x1x1x?xf32> +// CHECK: %[[REDUCE4:.*]] = tosa.reduce_sum %[[REDUCE3]] {axis = 3 : i32} : (tensor<1x1x1x?xf32>) -> tensor<1x1x1x1xf32> +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE4]] {new_shape = array} : (tensor<1x1x1x1xf32>) -> tensor<1xf32> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],f32> func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { @@ -281,11 +281,11 @@ func.func @test_reduce_sum$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK-LABEL: func.func @test_reduce_all$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor -// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_all"(%[[ARG0_BUILTIN]]) <{axis = 0 : i64}> : (tensor) -> tensor<1x?x?x?xi1> -// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_all"(%[[REDUCE1]]) <{axis = 1 : i64}> : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> -// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_all"(%[[REDUCE2]]) <{axis = 2 : i64}> : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> -// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_all"(%[[REDUCE3]]) <{axis = 3 : i64}> : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) <{new_shape = array}> : (tensor<1x1x1x1xi1>) -> tensor<1xi1> +// CHECK: %[[REDUCE1:.*]] = tosa.reduce_all %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xi1> +// CHECK: %[[REDUCE2:.*]] = tosa.reduce_all %[[REDUCE1]] {axis = 1 : i32} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> +// CHECK: %[[REDUCE3:.*]] = tosa.reduce_all %[[REDUCE2]] {axis = 2 : i32} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> +// CHECK: %[[REDUCE4:.*]] = tosa.reduce_all %[[REDUCE3]] {axis = 3 : i32} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE4]] {new_shape = array} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1> func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { @@ -300,8 +300,8 @@ func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch. // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor // CHECK: %[[ARG1:.*]] = torch.constant.int 0 // CHECK: %[[ARG2:.*]] = torch.constant.bool false -// CHECK: %[[REDUCE:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) <{axis = 0 : i64}> : (tensor) -> tensor<1x?x?x?xi1> -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) <{new_shape = array}> : (tensor<1x?x?x?xi1>) -> tensor +// CHECK: %[[REDUCE:.*]] = tosa.reduce_any %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xi1> +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE]] {new_shape = array} : (tensor<1x?x?x?xi1>) -> tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?,?],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1> func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> { @@ -316,11 +316,11 @@ func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !to // CHECK-LABEL: func.func @test_reduce_any$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],i1> -> tensor -// CHECK: %[[REDUCE1:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) <{axis = 0 : i64}> : (tensor) -> tensor<1x?x?x?xi1> -// CHECK: %[[REDUCE2:.*]] = "tosa.reduce_any"(%[[REDUCE1]]) <{axis = 1 : i64}> : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> -// CHECK: %[[REDUCE3:.*]] = "tosa.reduce_any"(%[[REDUCE2]]) <{axis = 2 : i64}> : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> -// CHECK: %[[REDUCE4:.*]] = "tosa.reduce_any"(%[[REDUCE3]]) <{axis = 3 : i64}> : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> -// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE4]]) <{new_shape = array}> : (tensor<1x1x1x1xi1>) -> tensor<1xi1> +// CHECK: %[[REDUCE1:.*]] = tosa.reduce_any %[[ARG0_BUILTIN]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xi1> +// CHECK: %[[REDUCE2:.*]] = tosa.reduce_any %[[REDUCE1]] {axis = 1 : i32} : (tensor<1x?x?x?xi1>) -> tensor<1x1x?x?xi1> +// CHECK: %[[REDUCE3:.*]] = tosa.reduce_any %[[REDUCE2]] {axis = 2 : i32} : (tensor<1x1x?x?xi1>) -> tensor<1x1x1x?xi1> +// CHECK: %[[REDUCE4:.*]] = tosa.reduce_any %[[REDUCE3]] {axis = 3 : i32} : (tensor<1x1x1x?xi1>) -> tensor<1x1x1x1xi1> +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.reshape %[[REDUCE4]] {new_shape = array} : (tensor<1x1x1x1xi1>) -> tensor<1xi1> // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<1xi1> -> !torch.vtensor<[1],i1> // CHECK: return %[[RESULT]] : !torch.vtensor<[1],i1> func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[1],i1> { @@ -333,7 +333,7 @@ func.func @test_reduce_any$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch. // CHECK-LABEL: func.func @torch.aten.rsqrt$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = "tosa.rsqrt"(%[[VAL_1]]) : (tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.rsqrt %[[VAL_1]] : (tensor) -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -349,7 +349,7 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_3]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.maximum %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -365,7 +365,7 @@ func.func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !to // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.minimum"(%[[VAL_2]], %[[VAL_3]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.minimum %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -381,7 +381,7 @@ func.func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !to // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.pow"(%[[VAL_1]], %[[VAL_3]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_1]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -400,8 +400,8 @@ func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) // CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<6.432100e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_1]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_4]], %[[VAL_6]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -421,8 +421,8 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_1]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_4]], %[[VAL_6]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -440,7 +440,7 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_2]], %[[VAL_3]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.greater %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> // CHECK: } @@ -456,7 +456,7 @@ func.func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_3]], %[[VAL_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.greater %[[VAL_3]], %[[VAL_2]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> // CHECK: } @@ -472,7 +472,7 @@ func.func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.equal"(%[[VAL_2]], %[[VAL_3]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.equal %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> // CHECK: } @@ -488,7 +488,7 @@ func.func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int -1 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list -// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32> // CHECK: } @@ -510,17 +510,17 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to // CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_6:.*]] = torch.constant.bool true // CHECK: %[[VAL_7:.*]] = torch.constant.bool false -// CHECK: %[[VAL_8:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_3]]) <{new_shape = array}> : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_10:.*]] = "tosa.reshape"(%[[VAL_3]]) <{new_shape = array}> : (tensor<4xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor -// CHECK: %[[VAL_13:.*]] = "tosa.sub"(%[[VAL_1]], %[[VAL_8]]) : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_14:.*]] = "tosa.add"(%[[VAL_9]], %[[VAL_12]]) : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> -// CHECK: %[[VAL_15:.*]] = "tosa.rsqrt"(%[[VAL_14]]) : (tensor<4x1xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_13]], %[[VAL_15]]) <{shift = 0 : i32}> : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_17:.*]] = "tosa.mul"(%[[VAL_16]], %[[VAL_10]]) <{shift = 0 : i32}> : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_18:.*]] = "tosa.add"(%[[VAL_17]], %[[VAL_11]]) : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_13:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_14:.*]] = tosa.add %[[VAL_9]], %[[VAL_12]] : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.rsqrt %[[VAL_14]] : (tensor<4x1xf32>) -> tensor<4x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i32} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_10]] {shift = 0 : i32} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_18:.*]] = tosa.add %[[VAL_17]], %[[VAL_11]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> // CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32> // CHECK: } @@ -542,7 +542,7 @@ func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,3,8,9,3,4],f32> -> tensor<10x3x8x9x3x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 4 // CHECK: %[[VAL_3:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<10x3x8x9x3x4xf32>) -> tensor<10x3x216x4xf32> // CHECK: %[[VAL_5:.*]] = tensor.cast %[[VAL_4]] : tensor<10x3x216x4xf32> to tensor<10x3x?x4xf32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<10x3x?x4xf32> -> !torch.vtensor<[10,3,?,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[10,3,?,4],f32> @@ -568,28 +568,28 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor // CHECK: %[[VAL_8:.*]] = torch.constant.int 2 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_8]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<1.200000e+01> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK: %[[VAL_11:.*]] = "tosa.reciprocal"(%[[VAL_10]]) : (tensor<1xf32>) -> tensor<1xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) <{axis = 3 : i64}> : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> -// CHECK: %[[VAL_13:.*]] = "tosa.reduce_sum"(%[[VAL_12]]) <{axis = 2 : i64}> : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> -// CHECK: %[[VAL_14:.*]] = "tosa.reduce_sum"(%[[VAL_13]]) <{axis = 1 : i64}> : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_15:.*]] = "tosa.reshape"(%[[VAL_14]]) <{new_shape = array}> : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_11]]) <{shift = 0 : i32}> : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_17]]) <{shift = 0 : i32}> : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_19:.*]] = "tosa.reduce_sum"(%[[VAL_18]]) <{axis = 3 : i64}> : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> -// CHECK: %[[VAL_20:.*]] = "tosa.reduce_sum"(%[[VAL_19]]) <{axis = 2 : i64}> : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> -// CHECK: %[[VAL_21:.*]] = "tosa.reduce_sum"(%[[VAL_20]]) <{axis = 1 : i64}> : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_22:.*]] = "tosa.reshape"(%[[VAL_21]]) <{new_shape = array}> : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_23:.*]] = "tosa.mul"(%[[VAL_22]], %[[VAL_11]]) <{shift = 0 : i32}> : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_24:.*]] = "tosa.reshape"(%[[VAL_4]]) <{new_shape = array}> : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> -// CHECK: %[[VAL_25:.*]] = "tosa.reshape"(%[[VAL_5]]) <{new_shape = array}> : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reciprocal %[[VAL_10]] : (tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_11]] {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_17]], %[[VAL_17]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_22]], %[[VAL_11]] {shift = 0 : i32} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> // CHECK: %[[VAL_26:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor -// CHECK: %[[VAL_27:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_16]]) : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_28:.*]] = "tosa.add"(%[[VAL_23]], %[[VAL_26]]) : (tensor<5x1x1x1xf32>, tensor) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_29:.*]] = "tosa.rsqrt"(%[[VAL_28]]) : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_30:.*]] = "tosa.mul"(%[[VAL_27]], %[[VAL_29]]) <{shift = 0 : i32}> : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_31:.*]] = "tosa.mul"(%[[VAL_30]], %[[VAL_24]]) <{shift = 0 : i32}> : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_32:.*]] = "tosa.add"(%[[VAL_31]], %[[VAL_25]]) : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_27:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_28:.*]] = tosa.add %[[VAL_23]], %[[VAL_26]] : (tensor<5x1x1x1xf32>, tensor) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_29:.*]] = tosa.rsqrt %[[VAL_28]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_30]], %[[VAL_24]] {shift = 0 : i32} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_32:.*]] = tosa.add %[[VAL_31]], %[[VAL_25]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> // CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32> // CHECK: } @@ -609,8 +609,8 @@ func.func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor< // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.equal"(%[[VAL_2]], %[[VAL_3]]) : (tensor, tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = "tosa.logical_not"(%[[VAL_4]]) : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.equal %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.logical_not %[[VAL_4]] : (tensor) -> tensor // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],i1> // CHECK: } @@ -629,7 +629,7 @@ func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK: %[[VAL_7:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_6]]) : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32> // CHECK: } @@ -649,7 +649,7 @@ func.func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4 // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_3]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.bitwise_and %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],si32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> // CHECK: } @@ -664,9 +664,9 @@ func.func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32> // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.reciprocal"(%[[VAL_2]]) : (tensor<1x1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_4:.*]] = "tosa.log"(%[[VAL_1]]) : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = "tosa.mul"(%[[VAL_4]], %[[VAL_3]]) <{shift = 0 : i32}> : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_4]], %[[VAL_3]] {shift = 0 : i32} : (tensor, tensor<1x1xf32>) -> tensor // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -683,7 +683,7 @@ func.func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vt // CHECK: %[[VAL_2:.*]] = torch.constant.none // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32> -// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> // CHECK: } @@ -702,7 +702,7 @@ func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> { // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<4x3xi32>) -> tensor<4x3x1xi32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32> // CHECK: } @@ -719,7 +719,7 @@ func.func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !to // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,3,1],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32> // CHECK: %[[VAL_2:.*]] = torch.constant.int -1 -// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) <{new_shape = array}> : (tensor<4x3xi32>) -> tensor<4x3x1xi32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<4x3xi32>) -> tensor<4x3x1xi32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x3x1xi32> -> !torch.vtensor<[4,3,1],si32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[4,3,1],si32> // CHECK: } @@ -752,7 +752,7 @@ func.func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !to // CHECK: %[[VAL_2:.*]] = torch.constant.none // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1> : tensor<3x4xi32>}> : () -> tensor<3x4xi32> -// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_4]]) : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xi32>) -> tensor<3x4xf32> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> // CHECK: } @@ -772,7 +772,7 @@ func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 0.000000e+00 // CHECK: %[[VAL_3:.*]] = torch.constant.bool false -// CHECK: %[[VAL_4:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -798,10 +798,10 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_12:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_11]]) : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> -// CHECK: %[[VAL_13:.*]] = "tosa.avg_pool2d"(%[[VAL_12]]) <{acc_type = f32, kernel = array, pad = array, stride = array}> : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_11]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> +// CHECK: %[[VAL_13:.*]] = tosa.avg_pool2d %[[VAL_12]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_15:.*]] = "tosa.transpose"(%[[VAL_13]], %[[VAL_14]]) : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> // CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32> // CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32> // CHECK: return %[[VAL_17]] : !torch.vtensor<[1,512,1,1],f32> @@ -828,9 +828,9 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> // CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true // CHECK: %[[VAL_I2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_2:.*]] = "tosa.reduce_max"(%[[VAL_1]]) <{axis = 2 : i64}> : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.argmax"(%[[VAL_1]]) <{axis = 2 : i64}> : (tensor<3x2x3xf32>) -> tensor<3x2xi64> -// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_3]]) <{new_shape = array}> : (tensor<3x2xi64>) -> tensor<3x2x1xi64> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> +// CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> // CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> // CHECK: return %[[VAL_6]] : tensor<3x2x1xf32> @@ -861,7 +861,7 @@ func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> { // CHECK: %[[CST5:.*]] = torch.constant.int 5 // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>}> : () -> tensor<5xi64> -// CHECK: %[[VAL_1:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<5xi64>) -> tensor<5xi64> +// CHECK: %[[VAL_1:.*]] = tosa.cast %[[VAL_0]] : (tensor<5xi64>) -> tensor<5xi64> // CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %1 : tensor<5xi64> -> !torch.vtensor<[5],si64> // CHECK: return %[[VAL_2]] : !torch.vtensor<[5],si64> func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> { @@ -897,11 +897,11 @@ func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],si64> { // CHECK: %[[CST0:.*]] = torch.constant.int 0 // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %[[VAL_2:.*]] = "tosa.equal"(%[[VAL_0]], %[[VAL_1]]) : (tensor, tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.logical_not"(%[[VAL_2]]) : (tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.equal %[[VAL_0]], %[[VAL_1]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.logical_not %[[VAL_2]] : (tensor) -> tensor // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x5x5xi8>}> : () -> tensor<1x1x5x5xi8> -// CHECK: %[[VAL_5:.*]] = "tosa.equal"(%[[INP]], %[[VAL_4]]) : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1> -// CHECK: %[[VAL_6:.*]] = "tosa.logical_not"(%[[VAL_5]]) : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_5:.*]] = tosa.equal %[[INP]], %[[VAL_4]] : (tensor<1x1x5x5xi8>, tensor<1x1x5x5xi8>) -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_6:.*]] = tosa.logical_not %[[VAL_5]] : (tensor<1x1x5x5xi1>) -> tensor<1x1x5x5xi1> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x1x5x5xi1> -> !torch.vtensor<[1,1,5,5],i1> // CHECK: return %[[VAL_7]] : !torch.vtensor<[1,1,5,5],i1> func.func @torch.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtensor<[1,1,5,5],i1> { @@ -927,8 +927,8 @@ func.func @torch.aten.copy(%arg0: !torch.vtensor<[1,1,5,5],ui8>) -> !torch.vtens // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x5xi64>}> : () -> tensor<3x5xi64> -// CHECK: %[[VAL_1:.*]] = "tosa.equal"(%[[INP]], %[[VAL_0]]) : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1> -// CHECK: %[[VAL_2:.*]] = "tosa.logical_not"(%[[VAL_1]]) : (tensor<3x5xi1>) -> tensor<3x5xi1> +// CHECK: %[[VAL_1:.*]] = tosa.equal %[[INP]], %[[VAL_0]] : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1> +// CHECK: %[[VAL_2:.*]] = tosa.logical_not %[[VAL_1]] : (tensor<3x5xi1>) -> tensor<3x5xi1> // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x5xi1> -> !torch.vtensor<[3,5],i1> // CHECK: return %[[VAL_3]] : !torch.vtensor<[3,5],i1> func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vtensor<[3,5],i1> { @@ -946,7 +946,7 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten // CHECK: %[[VAL_2:.*]] = torch.constant.int 4 // CHECK: %[[VAL_3:.*]] = torch.constant.none // CHECK: %[[VAL_4:.*]] = torch.constant.bool false -// CHECK: %[[VAL_5:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<1x128xi1>) -> tensor<1x128xi64> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x128xi1>) -> tensor<1x128xi64> // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x128xi64> -> !torch.vtensor<[1,128],si64> // CHECK: return %[[VAL_6]] : !torch.vtensor<[1,128],si64> // CHECK: } @@ -966,19 +966,19 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64> // CHECK: %[[VAL_4:.*]] = torch.constant.int -1 // CHECK: %[[VAL_5:.*]] = torch.constant.bool false -// CHECK: %[[VAL_6:.*]] = "tosa.cast"(%[[VAL_3]]) : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> -// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) <{new_shape = array}> : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.concat"(%[[VAL_8]], %[[VAL_9]], %[[VAL_7]]) <{axis = 3 : i64}> : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> -// CHECK: %[[VAL_11:.*]] = "tosa.reshape"(%[[VAL_2]]) <{new_shape = array}> : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.reshape"(%[[VAL_10]]) <{new_shape = array}> : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> +// CHECK: %[[VAL_10:.*]] = tosa.concat %[[VAL_8]], %[[VAL_9]], %[[VAL_7]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_14:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_13]]) <{shift = 0 : i32}> : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> -// CHECK: %[[VAL_15:.*]] = "tosa.reduce_sum"(%[[VAL_14]]) <{axis = 1 : i64}> : (tensor<8x3xi32>) -> tensor<8x1xi32> -// CHECK: %[[VAL_16:.*]] = "tosa.reshape"(%[[VAL_15]]) <{new_shape = array}> : (tensor<8x1xi32>) -> tensor<1x8xi32> -// CHECK: %[[VAL_17:.*]] = "tosa.gather"(%[[VAL_11]], %[[VAL_16]]) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> -// CHECK: %[[VAL_18:.*]] = "tosa.reshape"(%[[VAL_17]]) <{new_shape = array}> : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> +// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<8x1xi32>) -> tensor<1x8xi32> +// CHECK: %[[VAL_17:.*]] = tosa.gather %[[VAL_11]], %[[VAL_16]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> // CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> // CHECK: return %[[VAL_19]] : !torch.vtensor<[1,4,2],f32> // CHECK: } @@ -997,9 +997,9 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> -// CHECK: %[[VAL_7:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_6]]) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> -// CHECK: %[[VAL_8:.*]] = "tosa.cast"(%[[VAL_7]]) : (tensor<2x2xi32>) -> tensor<2x2xi64> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i32} : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> +// CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<2x2xi32>) -> tensor<2x2xi64> // CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64> // CHECK: return %[[VAL_9]] : !torch.vtensor<[2,2],si64> // CHECK: } @@ -1017,10 +1017,10 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc // CHECK: %[[VAL_3:.*]] = torch.constant.int 256 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_4]], %[[VAL_5]]) <{shift = 0 : i32}> : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> -// CHECK: %[[VAL_8:.*]] = "tosa.add"(%[[VAL_7]], %[[VAL_6]]) : (tensor<1x1x128x128xi32>, tensor) -> tensor<1x1x128x128xi32> -// CHECK: %[[VAL_9:.*]] = "tosa.cast"(%[[VAL_8]]) : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i32} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> +// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<1x1x128x128xi32>, tensor) -> tensor<1x1x128x128xi32> +// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> // CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> // CHECK: return %[[VAL_10]] : !torch.vtensor<[1,1,128,128],si64> // CHECK: } @@ -1037,7 +1037,7 @@ func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64> // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.int 511 -// CHECK: %[[VAL_4:.*]] = "tosa.clamp"(%[[VAL_1]]) <{max_fp = 5.110000e+02 : f32, max_int = 511 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}> : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 5.110000e+02 : f32, max_int = 511 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> // CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64> // CHECK: } @@ -1056,8 +1056,8 @@ func.func @torch.aten.clamp(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.cast"(%[[VAL_5]]) : (tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.select"(%[[VAL_3]], %[[VAL_6]], %[[VAL_2]]) : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_3]], %[[VAL_6]], %[[VAL_2]] : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } @@ -1075,7 +1075,7 @@ func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f3 // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_5]], %[[VAL_3]]) : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_4]], %[[VAL_5]], %[[VAL_3]] : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } @@ -1088,7 +1088,7 @@ func.func @torch.aten.masked_fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f3 // CHECK-LABEL: func.func @torch.aten.abs( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[15,15],si64> -> tensor<15x15xi64> -// CHECK: %[[VAL_2:.*]] = "tosa.abs"(%[[VAL_1]]) : (tensor<15x15xi64>) -> tensor<15x15xi64> +// CHECK: %[[VAL_2:.*]] = tosa.abs %[[VAL_1]] : (tensor<15x15xi64>) -> tensor<15x15xi64> // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<15x15xi64> -> !torch.vtensor<[15,15],si64> // CHECK: return %[[VAL_3]] : !torch.vtensor<[15,15],si64> // CHECK: } @@ -1105,7 +1105,7 @@ func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1> // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32> // CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) -> tensor<1x12x5x5xf32> +// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) -> tensor<1x12x5x5xf32> // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,5,5],f32> // CHECK: } @@ -1116,15 +1116,15 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to // ----- // CHECK-LABEL: func.func @torch.aten.remainder.Scalar( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 2 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = "tosa.reciprocal"(%[[VAL_5:.*]]) : (tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = "tosa.mul"(%[[VAL_3:.*]], %[[VAL_6:.*]]) <{shift = 0 : i32}> : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> -// CHECK: %[[VAL_8:.*]] = "tosa.floor"(%[[VAL_7]]) : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_5]], %[[VAL_8]]) <{shift = 0 : i32}> : (tensor, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_10:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_9]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5:.*]] : (tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3:.*]], %[[VAL_6:.*]] {shift = 0 : i32} : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.floor %[[VAL_7]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_5]], %[[VAL_8]] {shift = 0 : i32} : (tensor, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_10:.*]] = tosa.sub %[[VAL_3]], %[[VAL_9]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[2,4],f32> // CHECK: } diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir index 8c9420e7361a..c1d1d915b017 100644 --- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir +++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: torch.aten.mul.Scalar$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor<5xbf16> // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xbf16>}> : () -> tensor<1xbf16> -// CHECK: %[[VAL_2:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_1]]) <{shift = 0 : i32}> : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16> +// CHECK: %[[VAL_2:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i32} : (tensor<5xbf16>, tensor<1xbf16>) -> tensor<5xbf16> func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],bf16> { %float2.000000e00 = torch.constant.float 2.000000e+00 %0 = torch.aten.mul.Scalar %arg0, %float2.000000e00 : !torch.vtensor<[5],bf16>, !torch.float -> !torch.vtensor<[5],bf16> @@ -15,8 +15,8 @@ func.func @torch.aten.mul.Scalar$mixed_type(%arg0: !torch.vtensor<[5],bf16>) -> // CHECK-LABEL: torch.aten.add.Tensor$mixed_type_fp // CHECK-SAME: %[[VAL_0:.*]]: tensor<6xbf16> // CHECK-SAME: %[[VAL_1:.*]]: tensor<6xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<6xf32>) -> tensor<6xbf16> -// CHECK: %[[VAL_4:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_3]]) : (tensor<6xbf16>, tensor<6xbf16>) -> tensor<6xbf16> +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_1]] : (tensor<6xf32>) -> tensor<6xbf16> +// CHECK: %[[VAL_4:.*]] = tosa.add %[[VAL_0]], %[[VAL_3]] : (tensor<6xbf16>, tensor<6xbf16>) -> tensor<6xbf16> func.func @torch.aten.add.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[6],bf16>, %arg1: !torch.vtensor<[6],f32>, %arg2: !torch.float) -> !torch.vtensor<[6],bf16> { %float1 = torch.constant.float 1.000000e+00 %0 = torch.aten.add.Tensor %arg0, %arg1, %float1 : !torch.vtensor<[6],bf16>, !torch.vtensor<[6],f32>, !torch.float -> !torch.vtensor<[6],bf16> @@ -28,8 +28,8 @@ func.func @torch.aten.add.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[6],bf16>, // CHECK-LABEL: torch.aten.add.Tensor$mixed_type_int // CHECK-SAME: %[[VAL_0:.*]]: tensor<5xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<5xbf16> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<5xbf16>) -> tensor<5xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_0]], %[[VAL_2]]) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<5xbf16>) -> tensor<5xf32> +// CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_0]], %[[VAL_2]] : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> func.func @torch.aten.add.Tensor$mixed_type_int(%arg0: !torch.vtensor<[5],f32>, %arg1: !torch.vtensor<[5],bf16>) -> !torch.vtensor<[5],f32> { %int1 = torch.constant.int 1 %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[5],f32>, !torch.vtensor<[5],bf16>, !torch.int -> !torch.vtensor<[5],f32> @@ -41,8 +41,8 @@ func.func @torch.aten.add.Tensor$mixed_type_int(%arg0: !torch.vtensor<[5],f32>, // CHECK-LABEL: torch.aten.Scalar$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x32x64xi16> // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<256> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x1x32x64xi16>) -> tensor<1x1x32x64xi32> -// CHECK: %[[VAL_3:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x1x32x64xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x32x64xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor<1x1x32x64xi16>) -> tensor<1x1x32x64xi32> +// CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_2]], %[[VAL_1]] : (tensor<1x1x32x64xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x32x64xi32> func.func @torch.aten.Scalar$mixed_type(%arg0: !torch.vtensor<[1,1,32,64],si16>) -> !torch.vtensor<[1,1,32,64],si32> { %int1 = torch.constant.int 1 %int256 = torch.constant.int 256 @@ -55,7 +55,7 @@ func.func @torch.aten.Scalar$mixed_type(%arg0: !torch.vtensor<[1,1,32,64],si16>) // CHECK-LABEL: torch.aten.sub.Scalar$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor, // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.sub"(%[[VAL_0]], %[[VAL_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.sub %[[VAL_0]], %[[VAL_2]] : (tensor, tensor) -> tensor func.func @torch.aten.sub.Scalar$mixed_type(%arg0: !torch.vtensor<[],bf16>, %arg1: !torch.vtensor<[],bf16>) -> !torch.vtensor<[],bf16> { %int1 = torch.constant.int 1 %0 = torch.aten.sub.Scalar %arg0, %int1, %int1 : !torch.vtensor<[],bf16>, !torch.int, !torch.int -> !torch.vtensor<[],bf16> @@ -67,8 +67,8 @@ func.func @torch.aten.sub.Scalar$mixed_type(%arg0: !torch.vtensor<[],bf16>, %arg // CHECK-LABEL: torch.aten.maximum$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x1xi32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x1xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<1x3x1xi32>) -> tensor<1x3x1xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.maximum"(%[[VAL_2]], %[[VAL_1]]) : (tensor<1x3x1xf32>, tensor<1x3x1xf32>) -> tensor<1x3x1xf32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor<1x3x1xi32>) -> tensor<1x3x1xf32> +// CHECK: %[[VAL_3:.*]] = tosa.maximum %[[VAL_2]], %[[VAL_1]] : (tensor<1x3x1xf32>, tensor<1x3x1xf32>) -> tensor<1x3x1xf32> func.func @torch.aten.maximum$mixed_type(%arg0: !torch.vtensor<[1,3,1],si32>, %arg1: !torch.vtensor<[1,3,1],f32>) -> !torch.vtensor<[1,3,1],f32> { %0 = torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[1,3,1],si32>, !torch.vtensor<[1,3,1],f32> -> !torch.vtensor<[1,3,1],f32> return %0 : !torch.vtensor<[1,3,1],f32> @@ -79,8 +79,8 @@ func.func @torch.aten.maximum$mixed_type(%arg0: !torch.vtensor<[1,3,1],si32>, %a // CHECK-LABEL: torch.aten.bitwise_and.Tensor$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor, // CHECK-SAME: %[[VAL_1:.*]]: tensor -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_1]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.bitwise_and %[[VAL_2]], %[[VAL_1]] : (tensor, tensor) -> tensor func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],si16>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { %0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si16>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32> return %0 : !torch.vtensor<[?,?],si32> @@ -91,9 +91,9 @@ func.func @torch.aten.bitwise_and.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?], // CHECK-LABEL: torch.aten.div.Tensor$mixed_type_fp // CHECK-SAME: %[[VAL_0:.*]]: tensor, // CHECK-SAME: %[[VAL_1:.*]]: tensor -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.reciprocal"(%[[VAL_2]]) : (tensor) -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.mul"(%[[VAL_0]], %[[VAL_3]]) <{shift = 0 : i32}> : (tensor, tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_0]], %[[VAL_3]] {shift = 0 : i32} : (tensor, tensor) -> tensor func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],f32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],f32> return %0 : !torch.vtensor<[?, ?],f32> @@ -104,8 +104,8 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32> // CHECK-LABEL: torch.aten.div.Tensor$mixed_type_int // CHECK-SAME: %[[VAL_0:.*]]: tensor, // CHECK-SAME: %[[VAL_1:.*]]: tensor -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.div"(%[[VAL_2]], %[[VAL_1]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.div %[[VAL_2]], %[[VAL_1]] : (tensor, tensor) -> tensor func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> { %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32> return %0 : !torch.vtensor<[?, ?],si32> @@ -116,8 +116,8 @@ func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si1 // CHECK-LABEL: torch.aten.pow.Tensor$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor // CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = "tosa.pow"(%[[VAL_2]], %[[VAL_1]]) : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.pow %[[VAL_2]], %[[VAL_1]] : (tensor, tensor<1x1xf32>) -> tensor func.func @torch.aten.pow.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f32> { %fp0 = torch.constant.float 3.123400e+00 %0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f16>, !torch.float -> !torch.vtensor<[?,?],f32> diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 254a348cdec4..f22d5b785746 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -25,7 +25,7 @@ torch.class_type @c { } %c0 = torch.constant.int 0 %0 = torch.nn_module { - // expected-error @+1 {{'torch.slot' op is expected to match type and name of '"torch.attr"() {name = "g", type = !torch.int} : () -> ()}} + // expected-error @+1 {{'torch.slot' op is expected to match type and name of '"torch.attr"() <{name = "g", type = !torch.int}> : () -> ()}} torch.slot "f", %c0 : !torch.int } : !torch.nn.Module<"c"> diff --git a/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir b/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir index 2a55a3231548..c489375268b9 100644 --- a/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir +++ b/test/Dialect/TorchConversion/verify-tosa-backend-contract.mlir @@ -2,7 +2,7 @@ // CHECK: func.func @tanh func.func @tanh(%arg0: tensor) -> tensor { - %0 = "tosa.tanh"(%arg0) : (tensor) -> tensor + %0 = tosa.tanh %arg0 : (tensor) -> tensor return %0 : tensor }