[Torch] enhance fold of aten.squeeze.dim (#3558)
qingyunqu authored Jul 24, 2024
1 parent d1e172f commit aad1604
Showing 3 changed files with 101 additions and 22 deletions.
30 changes: 26 additions & 4 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -128,6 +128,17 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) {
   return FloatAttr::get(Float64Type::get(context), value);
 }
 
+static DenseElementsAttr reshapeDenseElementsAttr(DenseElementsAttr attr,
+                                                  ShapedType newType) {
+  // TODO: DenseElementsAttr::reshape is broken for bool splats.
+  // Once that ticket is fixed, we can remove this conditional.
+  if (attr.isSplat() && newType.getElementType().isInteger(/*width=*/1)) {
+    auto splatValue = attr.getValues<bool>()[0];
+    return DenseElementsAttr::get(newType, {splatValue});
+  }
+  return attr.reshape(newType);
+}
+
 static Value getScalarIntValue(Value input, Location loc,
                                PatternRewriter &rewriter) {
   auto inputType = input.getType();
@@ -798,11 +809,22 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
-  if (getOperand(0).getType() != getResult().getType())
+  auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
+  auto outType = dyn_cast<ValueTensorType>(getResult().getType());
+  if (!inType || !outType || !inType.areAllSizesKnown() ||
+      !outType.areAllSizesKnown() || !inType.hasDtype() ||
+      !outType.hasDtype()) {
     return nullptr;
-  if (auto tensorType = dyn_cast<BaseTensorType>(getOperand(0).getType())) {
-    if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
-      return getOperand(0);
   }
+
+  if (inType == outType) {
+    return getOperand(0);
+  }
+
+  DenseElementsAttr input =
+      dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
+  if (input) {
+    return reshapeDenseElementsAttr(input, outType.toBuiltinTensor());
+  }
   return nullptr;
 }
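
For illustration only (not part of the commit): a minimal sketch of the constant fold this change enables, assuming the IR is run through the Torch dialect canonicalizer (for example via torch-mlir-opt -canonicalize); the function name below is hypothetical.

// Before canonicalization: squeeze a constant literal along its unit dimension.
func.func @squeeze_cst_example() -> !torch.vtensor<[2],si64> {
  %int1 = torch.constant.int 1
  %0 = torch.vtensor.literal(dense<[[3], [4]]> : tensor<2x1xsi64>) : !torch.vtensor<[2,1],si64>
  %1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64>
  return %1 : !torch.vtensor<[2],si64>
}
// After canonicalization, the fold reshapes the literal and the squeeze disappears:
//   %0 = torch.vtensor.literal(dense<[3, 4]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
//   return %0 : !torch.vtensor<[2],si64>
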
26 changes: 18 additions & 8 deletions test/Conversion/TorchToStablehlo/view_like.mlir
@@ -379,15 +379,25 @@ func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vt
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.squeeze.dim$0$static(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,1,2],f32> {
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32>
-// CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[T0]] : tensor<2x1x2x1x2xf32> -> !torch.vtensor<[2,1,2,1,2],f32>
-// CHECK: return %[[T1]] : !torch.vtensor<[2,1,2,1,2],f32>
-func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> {
-  %int0 = torch.constant.int 0
-  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,1,2,1,2],f32>
-  return %0 : !torch.vtensor<[2,1,2,1,2],f32>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32>
+// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex>
+// CHECK: %[[T1:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<4xindex>) -> tensor<2x2x1x2xf32>
+// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<2x2x1x2xf32> -> !torch.vtensor<[2,2,1,2],f32>
+// CHECK: return %[[T2]] : !torch.vtensor<[2,2,1,2],f32>
+func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,1,2],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,2,1,2],f32>
+  return %0 : !torch.vtensor<[2,2,1,2],f32>
 }
 
 // -----
67 changes: 57 additions & 10 deletions test/Dialect/Torch/canonicalize.mlir
@@ -1507,20 +1507,67 @@ func.func @torch.aten.Float.Tensor(%arg0: !torch.float) -> !torch.float {
 }
 
 // CHECK-LABEL: func.func @torch.aten.squeeze$zero_rank(
-// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
-// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
-func.func @torch.aten.squeeze$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
-  %0 = torch.aten.squeeze %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32>
-  return %0 : !torch.tensor<[],f32>
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
+// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[],f32>
+func.func @torch.aten.squeeze$zero_rank(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
+  %0 = torch.aten.squeeze %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
 }
 
 // CHECK-LABEL: func.func @torch.aten.squeeze.dim$zero_rank(
-// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
-// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32>
-func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> {
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
+// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[],f32>
+func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
   %int0 = torch.constant.int 0
-  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32>
-  return %0 : !torch.tensor<[],f32>
+  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
 }
 
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst() -> !torch.vtensor<[2],si64> {
+// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<[127, 128]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
+// CHECK-NEXT: return %[[CST]]
+func.func @torch.aten.squeeze.dim$cst() -> !torch.vtensor<[2],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense<[[127], [128]]> : tensor<2x1xsi64>) : !torch.vtensor<[2,1],si64>
+  %1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64>
+  return %1 : !torch.vtensor<[2],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst_i1() -> !torch.vtensor<[3],i1> {
+// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, true]> : tensor<3xi1>) : !torch.vtensor<[3],i1>
+// CHECK-NEXT: return %[[CST]]
+func.func @torch.aten.squeeze.dim$cst_i1() -> !torch.vtensor<[3],i1> {
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense<[[true], [false], [true]]> : tensor<3x1xi1>) : !torch.vtensor<[3,1],i1>
+  %1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[3,1],i1>, !torch.int -> !torch.vtensor<[3],i1>
+  return %1 : !torch.vtensor<[3],i1>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst_splat_i1() -> !torch.vtensor<[3],i1> {
+// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<true> : tensor<3xi1>) : !torch.vtensor<[3],i1>
+// CHECK-NEXT: return %[[CST]]
+func.func @torch.aten.squeeze.dim$cst_splat_i1() -> !torch.vtensor<[3],i1> {
+  %int1 = torch.constant.int 1
+  %0 = torch.vtensor.literal(dense<true> : tensor<3x1xi1>) : !torch.vtensor<[3,1],i1>
+  %1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[3,1],i1>, !torch.int -> !torch.vtensor<[3],i1>
+  return %1 : !torch.vtensor<[3],i1>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$same_shape(
+// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,1],si64> {
+// CHECK-NEXT: return %[[ARG]]
+func.func @torch.aten.squeeze.dim$same_shape(%arg0: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,1],si64> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1],si64>
+  return %0 : !torch.vtensor<[2,1],si64>
+}
+
+// CHECK-LABEL: func.func @torch.aten.squeeze.dim$not_fold
+// CHECK: torch.aten.squeeze.dim
+func.func @torch.aten.squeeze.dim$not_fold(%arg0: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64>
+  return %0 : !torch.vtensor<[2],si64>
+}
+
 // CHECK-LABEL: func.func @torch.aten.tensor$one_elem(
