From 54cf76e755e30915ca77dff13050e8620094060d Mon Sep 17 00:00:00 2001
From: Jacob Gordon
Date: Wed, 15 Jan 2025 18:02:14 +0000
Subject: [PATCH] fix(ONNX): avoids resizing unsupported dimensions

---
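Notes:

onnx.Resize accepts scales/sizes for every dimension of its input, but the
torch.aten.__interpolate.size_list_scale_list op this pattern lowers to
resizes only the trailing spatial dimensions. The lowering therefore has to
reject any resize of the batch (0) or channel (1) dimension: statically
known shapes are checked when the pattern is matched, and dynamically
shaped inputs are guarded with torch.runtime.assert ops. As a rough sketch,
the guard emitted for the scale factor of the batch dimension amounts to
the following Torch-dialect IR (value names and the scales tensor type are
illustrative, not verbatim converter output):

  %zero = torch.constant.int 0
  %one = torch.constant.float 1.000000e+00
  %sel = torch.aten.select.int %scales, %zero, %zero : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
  %scale = torch.aten.item %sel : !torch.vtensor<[1],f32> -> !torch.float
  %ok = torch.aten.eq.float %scale, %one : !torch.float, !torch.float -> !torch.bool
  torch.runtime.assert %ok, "Unsupported: non-trivial scale factor for dimension 0"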
 .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp    | 145 +++++++++++++++++-
 .../TorchOnnxToTorch/simple_ops_q_to_z.mlir   |   6 +-
 2 files changed, 144 insertions(+), 7 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index 0d5417d1a9498..48e1fbb8f2996 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -180,6 +180,77 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter,
   return success();
 }
 
+// Emits a Torch-dialect !torch.bool which is true iff the scale factor at
+// the given dimension equals 1.0, i.e. iff that dimension is left unscaled.
+Value scaleIdentityComparisonOpForFactorAtDimensionIn(
+    Value givenScaleFactors, int64_t givenDimension, OpBinder binder,
+    ConversionPatternRewriter &rewriter) {
+  auto typeOfScaleFactors =
+      cast<Torch::BaseTensorType>(givenScaleFactors.getType());
+
+  Type typeOfSelectionFromScaleFactors =
+      typeOfScaleFactors.getWithSizesAndDtype(
+          ArrayRef<int64_t>{1}, typeOfScaleFactors.getOptionalDtype());
+
+  auto opLocation = binder.getLoc();
+
+  Value zeroAsOp = rewriter.create<Torch::ConstantIntOp>(
+      opLocation, rewriter.getI64IntegerAttr(0));
+
+  Value scaleIdentityAsOp = rewriter.create<Torch::ConstantFloatOp>(
+      opLocation, rewriter.getF64FloatAttr(1.0));
+
+  Value givenDimensionAsOp = rewriter.create<Torch::ConstantIntOp>(
+      opLocation, rewriter.getI64IntegerAttr(givenDimension));
+
+  Type typeOfScaleFactor = rewriter.getType<Torch::FloatType>();
+
+  Value selectionFromScaleFactorsAsOp = rewriter.create<Torch::AtenSelectIntOp>(
+      opLocation, typeOfSelectionFromScaleFactors, givenScaleFactors, zeroAsOp,
+      givenDimensionAsOp);
+
+  Value scaleFactorAsOp = rewriter.create<Torch::AtenItemOp>(
+      opLocation, typeOfScaleFactor, selectionFromScaleFactorsAsOp);
+
+  Type typeOfComparisonResult = rewriter.getType<Torch::BoolType>();
+
+  return rewriter.create<Torch::AtenEqFloatOp>(
+      opLocation, typeOfComparisonResult, scaleFactorAsOp, scaleIdentityAsOp);
+}
+
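+// Emits a Torch-dialect !torch.bool which is true iff the proposed target
+// size at the given dimension equals the original tensor's size there,
+// i.e. iff that dimension would not actually be resized.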
+Value originalSizeComparisonOpForSizeAtDimensionIn(
+    Value givenTargetSizes, Value givenOriginalTensor, int64_t givenDimension,
+    OpBinder binder, ConversionPatternRewriter &rewriter) {
+  auto typeOfTargetSizes =
+      cast<Torch::BaseTensorType>(givenTargetSizes.getType());
+
+  Type typeOfSelectionFromTargetSizes = typeOfTargetSizes.getWithSizesAndDtype(
+      ArrayRef<int64_t>{1}, typeOfTargetSizes.getOptionalDtype());
+
+  auto opLocation = binder.getLoc();
+
+  Value zeroAsOp = rewriter.create<Torch::ConstantIntOp>(
+      opLocation, rewriter.getI64IntegerAttr(0));
+
+  Type typeOfTargetSize = rewriter.getType<Torch::IntType>();
+
+  Value givenDimensionAsOp = rewriter.create<Torch::ConstantIntOp>(
+      opLocation, rewriter.getI64IntegerAttr(givenDimension));
+
+  Value selectionFromTargetSizesAsOp = rewriter.create<Torch::AtenSelectIntOp>(
+      opLocation, typeOfSelectionFromTargetSizes, givenTargetSizes, zeroAsOp,
+      givenDimensionAsOp);
+
+  Value targetSizeAsOp = rewriter.create<Torch::AtenItemOp>(
+      opLocation, typeOfTargetSize, selectionFromTargetSizesAsOp);
+
+  Value originalSizeAsOp = rewriter.create<Torch::AtenSizeIntOp>(
+      opLocation, givenOriginalTensor, givenDimensionAsOp);
+
+  Type typeOfComparisonResult = rewriter.getType<Torch::BoolType>();
+
+  return rewriter.create<Torch::AtenEqIntOp>(opLocation, typeOfComparisonResult,
+                                             targetSizeAsOp, originalSizeAsOp);
+}
+
 Value withUnsupportedDimensionsFilteredOut(
     Value givenTransformationVector, OpBinder binder,
     ConversionPatternRewriter &rewriter) {
@@ -2724,6 +2795,43 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
                                  "round_prefer_floor") ||
             binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
           return failure();
+
+        Value inputTensor = operands[0];
+        auto typeOfInputTensor =
+            cast<Torch::BaseTensorType>(inputTensor.getType());
+
+        auto sizesOfInputTensor = typeOfInputTensor.getSizes();
+        ArrayRef<int64_t> sizesOfOutputTensor = typeOfOutputTensor.getSizes();
+
+        int64_t const dimensionAssumedToBeBatch = 0;
+        int64_t const dimensionAssumedToBeChannel = 1;
+        int64_t nonResizableDimensions[] = {
+            dimensionAssumedToBeBatch,
+            dimensionAssumedToBeChannel,
+        };
+
+        auto unknownSize = Torch::kUnknownSize;
+
+        // Compile-time check for dimensions of static size
+        for (auto eachDimension : nonResizableDimensions) {
+          auto eachSizeOfInputTensor = sizesOfInputTensor[eachDimension];
+          auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDimension];
+
+          if (eachSizeOfInputTensor == unknownSize ||
+              eachSizeOfOutputTensor == unknownSize) {
+            continue;
+          } else if (eachSizeOfInputTensor == eachSizeOfOutputTensor) {
+            continue;
+          }
+
+          auto scalingIntentErrorMessage =
+              "unsupported: non-trivial intent to scale dimension: " +
+              std::to_string(eachDimension);
+
+          return rewriter.notifyMatchFailure(binder.op,
+                                             scalingIntentErrorMessage);
+        }
+
         if (antialias != 0) {
           return rewriter.notifyMatchFailure(
               binder.op,
@@ -2773,10 +2881,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           rewriter.create<Torch::ConstantStrOp>(opLocation, modeStr);
         }
 
-        Value inputTensor = operands[0];
-        auto typeOfInputTensor =
-            cast<Torch::BaseTensorType>(inputTensor.getType());
-        auto sizesOfInputTensor = typeOfInputTensor.getSizes();
         unsigned rankOfInputTensor = sizesOfInputTensor.size();
 
         // supported modes:
@@ -2824,10 +2928,43 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 
         if (operands.size() < 4) {
           Value proposedScaleFactorsAsOp = operands[2];
+
+          // run-time scale factor check for dynamic sizes
+          for (auto eachDimension : nonResizableDimensions) {
+            auto eachScaleIdentityComparisonAsOp =
+                scaleIdentityComparisonOpForFactorAtDimensionIn(
+                    proposedScaleFactorsAsOp, eachDimension, binder, rewriter);
+
+            auto eachErrorMessage =
+                "Unsupported: non-trivial scale factor for dimension " +
+                std::to_string(eachDimension);
+
+            rewriter.create<Torch::RuntimeAssertOp>(
+                opLocation, eachScaleIdentityComparisonAsOp,
+                rewriter.getStringAttr(eachErrorMessage));
+          }
+
           filteredScaleFactorsAsOp = withUnsupportedDimensionsFilteredOut(
               proposedScaleFactorsAsOp, binder, rewriter);
         } else {
           Value proposedSizesAsOp = operands[3];
+
+          // run-time target size check for dynamic sizes
+          for (auto eachDimension : nonResizableDimensions) {
+            auto eachSizeComparisonAsOp =
+                originalSizeComparisonOpForSizeAtDimensionIn(
+                    proposedSizesAsOp, inputTensor, eachDimension, binder,
+                    rewriter);
+
+            auto eachErrorMessage =
+                "Unsupported: non-trivial resizing of dimension " +
+                std::to_string(eachDimension);
+
+            rewriter.create<Torch::RuntimeAssertOp>(
+                opLocation, eachSizeComparisonAsOp,
+                rewriter.getStringAttr(eachErrorMessage));
+          }
+
           filteredSizesAsOp = withUnsupportedDimensionsFilteredOut(
               proposedSizesAsOp, binder, rewriter);
         }
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index 16c86218dbc8b..1e1d09def615c 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -2256,7 +2256,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
 // CHECK-LABEL: func.func @test_resize_sizes_nearest
 func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %12, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
 }
@@ -2267,7 +2267,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
 func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
   // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %12, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
     torch.onnx.coordinate_transformation_mode = "half_pixel",
     torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
@@ -2280,7 +2280,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
 // CHECK-LABEL: func.func @test_resize_sizes_linear
 func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %12, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
 }