fix(ONNX): avoids resizing conventionally fixed dimensions
bjacobgordon committed Jan 9, 2025
commit 7aec80b (1 parent: 76368bd)
Showing 4 changed files with 66 additions and 15 deletions.
include/torch-mlir/Dialect/Torch/IR/TorchTypes.h (4 additions, 0 deletions)
@@ -84,6 +84,10 @@ class BaseTensorType : public Type {
   /// Enable isa/dyn_cast for BaseTensorType.
   static bool classof(Type type);
 
+  /// Element-wise comparison of this tensor's sizes against those of `that`
+  std::vector<std::optional<bool>>
+  shapeComparisonAgainst(BaseTensorType that) const;
+
   /// Return true if this type has the same sizes and dtype as the other.
   bool hasSameSizesAndDtype(BaseTensorType other) const;
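The new declaration returns one entry per dimension, with tri-state semantics: true when both sizes are static and equal, false when both are static and differ, and std::nullopt when either side is dynamic. A minimal sketch of consuming that result (hypothetical `input`/`output` variables; not part of this commit):

// Sketch only: assumes `input` and `output` are BaseTensorType values of
// equal rank, e.g. with sizes [1,1,?,4] and [1,1,2,8].
std::vector<std::optional<bool>> cmp = input.shapeComparisonAgainst(output);
// For the example sizes above, cmp == { true, true, std::nullopt, false }:
// entry 2 is unknown because one side is dynamic; entry 3 is a definite
// mismatch because 4 != 8.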
lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp (26 additions, 3 deletions)
@@ -2717,6 +2717,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();

Value inputTensor = operands[0];
Torch::ValueTensorType inputTensor_blueprint =
cast<Torch::ValueTensorType>(inputTensor.getType());

std::vector<std::optional<bool>> shapeComparison =
inputTensor_blueprint.shapeComparisonAgainst(
outputTensor_blueprint);

// Comparisons of the dimensions assumed to carry the batch and channel
auto shapeComparisonForFixedDimensions =
ArrayRef(shapeComparison).take_front(2);

for (auto eachDimensionComparison : shapeComparisonForFixedDimensions) {
if (eachDimensionComparison == std::nullopt) {
return rewriter.notifyMatchFailure(
binder.op, "Sizes for batch and channel dimensions must be "
"statically defined");
}
if (eachDimensionComparison == false) {
return rewriter.notifyMatchFailure(
binder.op,
"Unexpected intent to resize the batch/channel dimensions");
}
};

if (antialias != 0) {
return rewriter.notifyMatchFailure(
binder.op,
@@ -2749,9 +2775,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, "unimplemented: cubic coeff must be -0.75");
}

Value inputTensor = operands[0];
Torch::ValueTensorType inputTensor_blueprint =
cast<Torch::ValueTensorType>(inputTensor.getType());
ArrayRef<int64_t> inputTensor_dimensions =
inputTensor_blueprint.getSizes();
unsigned rank = inputTensor_dimensions.size();
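The hunk above threads the comparison into the onnx.Resize lowering: the first two entries, which conventionally carry batch and channel in NCHW layout, must be definitely equal or the pattern bails out. A standalone sketch of that guard under the same assumptions (simplified; not the code in this commit):

#include <optional>
#include <vector>

#include "llvm/ADT/ArrayRef.h"

// Returns true only when the leading `count` dimensions are statically
// known to match; a std::nullopt (dynamic) or false (mismatched) entry
// means the resize cannot be proven safe for those dimensions.
static bool leadingDimsMatch(llvm::ArrayRef<std::optional<bool>> comparison,
                             unsigned count = 2) {
  for (std::optional<bool> entry : comparison.take_front(count)) {
    if (entry != true)
      return false; // nullopt and false both fail the guard
  }
  return true;
}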
lib/Dialect/Torch/IR/TorchTypes.cpp (24 additions, 0 deletions)
@@ -217,6 +217,30 @@ static bool isValidTorchDtype(Type dtype) {
   return false;
 }
 
+std::vector<std::optional<bool>>
+BaseTensorType::shapeComparisonAgainst(BaseTensorType that) const {
+  auto this_dimensions = getSizes();
+  auto that_dimensions = that.getSizes();
+
+  auto this_rank = this_dimensions.size();
+  auto that_rank = that_dimensions.size();
+
+  assert((this_rank == that_rank) && "Ranks must match to compare dimensions");
+
+  std::vector<std::optional<bool>> runningComparison = {};
+  auto dimensionPairs = llvm::zip(this_dimensions, that_dimensions);
+
+  for (auto [eachLHDimension, eachRHDimension] : dimensionPairs) {
+    if (eachLHDimension == kUnknownSize || eachRHDimension == kUnknownSize) {
+      runningComparison.push_back(std::nullopt);
+    } else {
+      runningComparison.push_back(eachLHDimension == eachRHDimension);
+    }
+  }
+
+  return runningComparison;
+}
+
 bool BaseTensorType::hasSameSizesAndDtype(BaseTensorType other) const {
   return getOptionalSizes() == other.getOptionalSizes() &&
          getOptionalDtype() == other.getOptionalDtype();
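To make the tri-state result concrete, here is a self-contained model of the routine above that operates on plain size vectors, with -1 standing in for a dynamic dimension the way kUnknownSize does (illustrative only):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for a dynamic (unknown) size

std::vector<std::optional<bool>>
compareShapes(const std::vector<int64_t> &lhs,
              const std::vector<int64_t> &rhs) {
  assert(lhs.size() == rhs.size() && "Ranks must match to compare dimensions");
  std::vector<std::optional<bool>> result;
  for (size_t i = 0; i < lhs.size(); ++i) {
    if (lhs[i] == kDynamic || rhs[i] == kDynamic)
      result.push_back(std::nullopt); // either side dynamic: unknowable
    else
      result.push_back(lhs[i] == rhs[i]); // both static: definite answer
  }
  return result;
}

int main() {
  // [1,1,2,4] vs. [1,1,?,8] -> true true nullopt false
  for (const auto &entry : compareShapes({1, 1, 2, 4}, {1, 1, kDynamic, 8}))
    std::cout << (entry ? (*entry ? "true " : "false ") : "nullopt ");
  std::cout << "\n";
}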
test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir (12 additions, 12 deletions)
@@ -2254,35 +2254,35 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
 // -----
 
 // CHECK-LABEL: func.func @test_resize_sizes_nearest
-func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
-  %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32>
+  %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32>
+  return %0 : !torch.vtensor<[1,1,?,?],f32>
 }
 
 // -----
 
 // CHECK-LABEL: func.func @test_resize_sizes_nearest
-func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
   // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32>
   %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
     torch.onnx.coordinate_transformation_mode = "half_pixel",
-    torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?,?],f32>
+    torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32>
+  return %0 : !torch.vtensor<[1,1,?,?],f32>
 }
 
 // -----
 
 // CHECK-LABEL: func.func @test_resize_sizes_linear
-func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
+func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],
 f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   %none = torch.constant.none
-  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
-  %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
-  return %0 : !torch.vtensor<[?,?,?,?],f32>
+  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,1,?,?],f32>
+  %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,1,?,?],f32>
+  return %0 : !torch.vtensor<[1,1,?,?],f32>
 }
 
 // -----