diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index c01dec62c2569..a03277b263b92 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -234,6 +234,62 @@ indexOfFrontmostNonTrivialScaleFactorIn(Value givenVector, OpBinder binder,
   return std::nullopt;
 }
 
+std::optional<int64_t> indexOfFrontmostNonTrivialTargetSizeIn(
+    Value givenTargetSizeVector, Value givenOriginalTensor, OpBinder binder,
+    ConversionPatternRewriter &rewriter) {
+  Torch::BaseTensorType typeOfOriginalTensor =
+      cast<Torch::BaseTensorType>(givenOriginalTensor.getType());
+  auto sizesOfOriginalTensor = typeOfOriginalTensor.getSizes();
+
+  Torch::BaseTensorType typeOfTargetSizeVector =
+      cast<Torch::BaseTensorType>(givenTargetSizeVector.getType());
+
+  SmallVector<int64_t> sizesOfIndexPath;
+  sizesOfIndexPath.push_back(1);
+  Type typeOfIndexPath = typeOfTargetSizeVector.getWithSizesAndDtype(
+      llvm::ArrayRef(sizesOfIndexPath),
+      typeOfTargetSizeVector.getOptionalDtype());
+
+  auto opLocation = binder.getLoc();
+
+  Value zeroAsOp = rewriter.create<Torch::ConstantIntOp>(
+      opLocation, rewriter.getI64IntegerAttr(0));
+
+  Type typeOfEveryTargetSize = rewriter.getType<Torch::IntType>();
+
+  for (auto [eachDimension, eachOriginalSize] :
+       llvm::enumerate(sizesOfOriginalTensor)) {
+    Value eachDimensionAsOp = rewriter.create<Torch::ConstantIntOp>(
+        opLocation, rewriter.getType<Torch::IntType>(),
+        rewriter.getIntegerAttr(rewriter.getIntegerType(64), eachDimension));
+
+    Value indexPathOfEachTargetSize = rewriter.create<Torch::AtenSelectIntOp>(
+        opLocation, typeOfIndexPath, givenTargetSizeVector, zeroAsOp,
+        eachDimensionAsOp);
+
+    Value eachTargetSizeAsOp = rewriter.create<Torch::AtenItemOp>(
+        opLocation, typeOfEveryTargetSize, indexPathOfEachTargetSize);
+
+    Value eachOriginalSizeAsOp = rewriter.create<Torch::ConstantIntOp>(
+        opLocation, rewriter.getI64IntegerAttr(eachOriginalSize));
+
+    Value eachSizeComparisonAsOp = rewriter.create<Torch::AtenEqIntOp>(
+        opLocation, rewriter.getType<Torch::BoolType>(), eachTargetSizeAsOp,
+        eachOriginalSizeAsOp);
+
+    bool eachTargetSizeMatchesOriginal;
+
+    matchPattern(eachSizeComparisonAsOp,
+                 Torch::m_TorchConstantBool(&eachTargetSizeMatchesOriginal));
+
+    if (!eachTargetSizeMatchesOriginal) {
+      return static_cast<int64_t>(eachDimension);
+    }
+  }
+
+  return std::nullopt;
+}
+
 Value withUnsupportedDimensionsFilteredOut(
     Value givenTransformationVector, OpBinder binder,
     ConversionPatternRewriter &rewriter) {
@@ -2933,6 +2989,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                 proposedScalingsForOutputTensor, binder, rewriter);
           } else {
             Value proposedSizesForOutputTensor = operands[3];
+
+            auto indexOfNonTrivialTargetSize =
+                indexOfFrontmostNonTrivialTargetSizeIn(
+                    proposedSizesForOutputTensor, inputTensor, binder, rewriter);
+
+            if (indexOfNonTrivialTargetSize == dimensionAssumedToBeBatch ||
+                indexOfNonTrivialTargetSize == dimensionAssumedToBeChannel) {
+              return rewriter.notifyMatchFailure(
+                  binder.op,
+                  "Unsupported: non-trivial resizing of dimension " +
+                      std::to_string(indexOfNonTrivialTargetSize.value()));
+            }
+
             filteredSizesForOutputTensor = withUnsupportedDimensionsFilteredOut(
                 proposedSizesForOutputTensor, binder, rewriter);
           }
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 16c86218dbc8b..6779a9011cb70 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -2256,7 +2256,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
 // CHECK-LABEL: func.func @test_resize_sizes_nearest
 func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %16, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
 }
@@ -2267,7 +2267,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
 func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
   // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %16, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
     torch.onnx.coordinate_transformation_mode = "half_pixel",
     torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
@@ -2280,7 +2280,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
 func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
 f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %16, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
 }
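For reference, the check the patch adds can be modeled on plain shape vectors as in the standalone sketch below. This is illustrative only and not code from the patch: firstMismatchedDimension, batchDim, and channelDim are made-up names standing in for indexOfFrontmostNonTrivialTargetSizeIn, dimensionAssumedToBeBatch, and dimensionAssumedToBeChannel, and the real helper performs the same per-dimension comparison on torch-mlir Values (via AtenSelectIntOp, AtenItemOp, AtenEqIntOp, and matchPattern) rather than on std::vector.

// Standalone sketch of the dimension check (not part of the patch).
// Assumes the ONNX Resize "sizes" operand has one entry per input dimension.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Return the first dimension whose requested output size differs from the
// input size, or std::nullopt if every dimension is left unchanged.
std::optional<int64_t>
firstMismatchedDimension(const std::vector<int64_t> &originalSizes,
                         const std::vector<int64_t> &targetSizes) {
  for (size_t dim = 0; dim < originalSizes.size(); ++dim)
    if (targetSizes[dim] != originalSizes[dim])
      return static_cast<int64_t>(dim); // frontmost non-trivial resize
  return std::nullopt;
}

int main() {
  const int64_t batchDim = 0, channelDim = 1;
  std::vector<int64_t> original = {1, 1, 2, 4}; // N, C, H, W
  std::vector<int64_t> target = {1, 1, 4, 8};   // only spatial dims change

  auto mismatch = firstMismatchedDimension(original, target);
  if (mismatch == batchDim || mismatch == channelDim)
    std::cout << "reject: non-trivial resizing of dimension " << *mismatch << "\n";
  else
    std::cout << "ok: spatial-only resize\n"; // this branch is taken here
}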