Skip to content

Commit

Permalink
fix(ONNX): avoid resizing non-targetable dimensions when dynamically…
Browse files Browse the repository at this point in the history
… defined
  • Loading branch information
bjacobgordon committed Jan 15, 2025
1 parent 5c47870 commit cb20894
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 3 deletions.
69 changes: 69 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,62 @@ indexOfFrontmostNonTrivialScaleFactorIn(Value givenVector, OpBinder binder,
return std::nullopt;
}

/// Returns the index of the frontmost (lowest-numbered) dimension of
/// `givenOriginalTensor` whose statically-known size does not provably match
/// the corresponding element of `givenTargetSizeVector`, or std::nullopt when
/// every dimension's target size folds to a constant equal to the original.
///
/// NOTE(review): as a side effect of the analysis, this materializes ops
/// (constants, selects, item-extractions, comparisons) into the IR through
/// `rewriter` at `binder.getLoc()`.
std::optional<int64_t> indexOfFrontmostNonTrivialTargetSizeIn(
    Value givenTargetSizeVector, Value givenOriginalTensor, OpBinder binder,
    ConversionPatternRewriter &rewriter) {
  Torch::BaseTensorType typeOfOriginalTensor =
      cast<Torch::BaseTensorType>(givenOriginalTensor.getType());
  auto sizesOfOriginalTensor = typeOfOriginalTensor.getSizes();

  Torch::BaseTensorType typeOfTargetSizeVector =
      cast<Torch::BaseTensorType>(givenTargetSizeVector.getType());

  // Rank-1, single-element tensor type used to slice one target size out of
  // the target-size vector.
  SmallVector<int64_t> sizesOfIndexPath;
  sizesOfIndexPath.push_back(1);
  Type typeOfIndexPath = typeOfTargetSizeVector.getWithSizesAndDtype(
      llvm::ArrayRef(sizesOfIndexPath),
      typeOfTargetSizeVector.getOptionalDtype());

  auto opLocation = binder.getLoc();

  Value zeroAsOp = rewriter.create<Torch::ConstantIntOp>(
      opLocation, rewriter.getI64IntegerAttr(0));

  Type typeOfEveryTargetSize = rewriter.getType<Torch::IntType>();

  for (auto [eachDimension, eachOriginalSize] :
       llvm::enumerate(sizesOfOriginalTensor)) {
    Value eachDimensionAsOp = rewriter.create<Torch::ConstantIntOp>(
        opLocation, rewriter.getType<Torch::IntType>(),
        rewriter.getIntegerAttr(rewriter.getIntegerType(64), eachDimension));

    Value indexPathOfEachTargetSize = rewriter.create<Torch::AtenSelectIntOp>(
        opLocation, typeOfIndexPath, givenTargetSizeVector, zeroAsOp,
        eachDimensionAsOp);

    Value eachTargetSizeAsOp = rewriter.create<Torch::AtenItemOp>(
        opLocation, typeOfEveryTargetSize, indexPathOfEachTargetSize);

    Value eachOriginalSizeAsOp = rewriter.create<Torch::ConstantIntOp>(
        opLocation, rewriter.getI64IntegerAttr(eachOriginalSize));

    Value eachSizeComparisonAsOp = rewriter.create<Torch::AtenEqIntOp>(
        opLocation, rewriter.getType<Torch::BoolType>(), eachTargetSizeAsOp,
        eachOriginalSizeAsOp);

    // BUGFIX: the original declared this bool uninitialized and ignored
    // matchPattern's return value. When the comparison cannot be folded to a
    // constant (e.g. the target size is only known at runtime), the bool was
    // read while uninitialized — undefined behavior. Initialize it to false
    // and explicitly treat a failed fold as a potential mismatch, so a
    // dynamically-defined size conservatively reports its dimension as
    // non-trivially resized.
    bool eachTargetSizeMatchesOriginal = false;
    if (!matchPattern(
            eachSizeComparisonAsOp,
            Torch::m_TorchConstantBool(&eachTargetSizeMatchesOriginal)))
      eachTargetSizeMatchesOriginal = false;

    if (!eachTargetSizeMatchesOriginal) {
      return static_cast<int64_t>(eachDimension);
    }
  }

  // Every dimension's target size provably equals the original size.
  return std::nullopt;
}

Value withUnsupportedDimensionsFilteredOut(
Value givenTransformationVector, OpBinder binder,
ConversionPatternRewriter &rewriter) {
Expand Down Expand Up @@ -2933,6 +2989,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
proposedScalingsForOutputTensor, binder, rewriter);
} else {
Value proposedSizesForOutputTensor = operands[3];

auto indexOfNonTrivialTargetSize =
indexOfFrontmostNonTrivialTargetSizeIn(
proposedSizesForOutputTensor, inputTensor, binder, rewriter);

if (indexOfNonTrivialTargetSize == dimensionAssumedToBeBatch ||
indexOfNonTrivialTargetSize == dimensionAssumedToBeChannel) {
return rewriter.notifyMatchFailure(
binder.op,
"Unsupported: non-trivial resizing of dimension " +
std::to_string(indexOfNonTrivialTargetSize.value()));
}

filteredSizesForOutputTensor = withUnsupportedDimensionsFilteredOut(
proposedSizesForOutputTensor, binder, rewriter);
}
Expand Down
6 changes: 3 additions & 3 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2256,7 +2256,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
// CHECK-LABEL: func.func @test_resize_sizes_nearest
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %16, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
Expand All @@ -2267,7 +2267,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %16, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
torch.onnx.coordinate_transformation_mode = "half_pixel",
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
Expand All @@ -2280,7 +2280,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %16, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
Expand Down

0 comments on commit cb20894

Please sign in to comment.