fix(ONNX): avoids resizing non-scalable dimensions
bjacobgordon committed Jan 13, 2025
1 parent 17029e6 commit 574f4fe
Showing 2 changed files with 64 additions and 7 deletions.
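In outline, the C++ change does two things before lowering onnx.Resize: it rejects a resize of the batch or channel dimension at conversion time when both sizes are statically known, and it emits a Torch::RuntimeAssertOp for dimensions whose size is only known dynamically. A condensed plain-C++ sketch of that decision logic is below; the names are hypothetical and this is not the torch-mlir API used in the diff:

```cpp
// Illustrative sketch only: split the non-scalable-dimension check into a
// static part (fail now) and a deferred part (assert at run time).
#include <cstdint>
#include <stdexcept>
#include <vector>

constexpr int64_t kUnknownSize = -1; // stand-in for Torch::kUnknownSize

// Returns the non-scalable dimensions (batch = 0, channel = 1) whose equality
// could not be proven statically and therefore need a runtime assert; throws
// if a static mismatch is found.
std::vector<int64_t>
checkNonScalableDims(const std::vector<int64_t> &inputSizes,
                     const std::vector<int64_t> &outputSizes) {
  std::vector<int64_t> needRuntimeAssert;
  for (int64_t dim : {int64_t{0}, int64_t{1}}) { // batch, channel
    int64_t in = inputSizes[dim];
    int64_t out = outputSizes[dim];
    if (in == kUnknownSize || out == kUnknownSize) {
      needRuntimeAssert.push_back(dim); // size unknown until run time
      continue;
    }
    if (in != out)
      throw std::invalid_argument("non-scalable dimension would be resized");
  }
  return needRuntimeAssert;
}
```

Note one simplification: the sketch only reports the dimensions that escape the static check, whereas the patch below emits the runtime assert for both non-scalable dimensions unconditionally after the static check passes.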
65 changes: 61 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2749,8 +2749,69 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, "unimplemented: cubic coeff must be -0.75");
}

Value inputTensor = operands[0];
Torch::ValueTensorType inputTensor_type =
cast<Torch::ValueTensorType>(inputTensor.getType());
ArrayRef<int64_t> inputTensor_sizes = inputTensor_type.getSizes();
ArrayRef<int64_t> outputTensor_sizes = outputTensor_type.getSizes();

int64_t const batchDimension = 0;
int64_t const channelDimension = 1;
int64_t nonScalableDimensions[] = {
batchDimension,
channelDimension,
};

auto errorMessageForScaling = [](int64_t givenDimension) {
switch (givenDimension) {
case batchDimension:
return "Unexpected intent to scale the batch dimension";
case channelDimension:
return "Unexpected intent to scale the channel dimension";
default:
return "Scalable dimension treated as non-scalable";
}
};

auto unknownSize = Torch::kUnknownSize;

// Compile-time check for dimensions of static size
for (auto eachDimension : nonScalableDimensions) {
auto eachInputSize = inputTensor_sizes[eachDimension];
auto eachOutputSize = outputTensor_sizes[eachDimension];

if (eachInputSize == unknownSize || eachOutputSize == unknownSize) {
continue;
} else if (eachInputSize == eachOutputSize) {
continue;
}

return rewriter.notifyMatchFailure(
binder.op, errorMessageForScaling(eachDimension));
}

auto binderLocation = binder.getLoc();

// Run-time check for dimensions of dynamic size
for (auto eachDimension : nonScalableDimensions) {
auto eachDimensionAsValue = rewriter.create<Torch::ConstantIntOp>(
binderLocation, rewriter.getI64IntegerAttr(eachDimension));

Value eachInputSizeAsValue = rewriter.create<Torch::AtenSizeIntOp>(
binderLocation, inputTensor, eachDimensionAsValue);

int64_t eachOutputSize = outputTensor_sizes[eachDimension];
Value eachOutputSizeAsValue = rewriter.create<Torch::ConstantIntOp>(
binderLocation, rewriter.getI64IntegerAttr(eachOutputSize));

Value eachSizeComparison = rewriter.create<Torch::AtenEqIntOp>(
binderLocation, eachInputSizeAsValue, eachOutputSizeAsValue);

rewriter.create<Torch::RuntimeAssertOp>(
binderLocation, eachSizeComparison,
rewriter.getStringAttr(errorMessageForScaling(eachDimension)));
}

Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binderLocation, false);
Value cstTrue =
@@ -2770,10 +2831,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
rewriter.create<Torch::ConstantStrOp>(binderLocation, modeStr);
}

Value inputTensor = operands[0];
Torch::ValueTensorType inputTensor_type =
cast<Torch::ValueTensorType>(inputTensor.getType());
ArrayRef<int64_t> inputTensor_sizes = inputTensor_type.getSizes();
unsigned inputTensor_rank = inputTensor_sizes.size();

// supported modes:
6 changes: 3 additions & 3 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -2256,7 +2256,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
// CHECK-LABEL: func.func @test_resize_sizes_nearest
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %str, %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
@@ -2267,7 +2267,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %[[STR]], %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
torch.onnx.coordinate_transformation_mode = "half_pixel",
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
@@ -2280,7 +2280,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %str, %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
