Remove convolution_overrideable, convolution_backward_overrideable (llvm#1984)

The ops `aten.convolution_overrideable` and
`aten.convolution_backward_overrideable` are currently not e2e tested
in Torch-MLIR. Moreover, there is no way to add e2e tests for them
because the ops cannot be called using the CPU backend (this also
prevents adding tested dtype functions for these ops). Since these two
ops are not expected to ever appear in PyTorch traces obtained through
standard means (pytorch/pytorch#97481),
Torch-MLIR should not have to worry about them.
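For context, here is a minimal sketch (not part of the commit) of the CPU-backend limitation described above; it assumes the standard `torch.ops` calling convention, which `aten::convolution` and `aten::convolution_overrideable` share:

import torch

x = torch.randn(1, 3, 8, 8)   # NCHW input
w = torch.randn(4, 3, 3, 3)   # out_channels x in_channels x kH x kW

# `aten::convolution` dispatches to a real CPU kernel and succeeds.
out = torch.ops.aten.convolution(
    x, w, None,               # input, weight, bias
    [1, 1], [0, 0], [1, 1],   # stride, padding, dilation
    False, [0, 0], 1)         # transposed, output_padding, groups
print(out.shape)              # torch.Size([1, 4, 6, 6])

# `aten::convolution_overrideable` exists only as an override point for
# out-of-tree backends; on CPU tensors its default kernel is expected to
# raise, which is why no e2e test (or tested dtype function) can
# exercise the op.
try:
    torch.ops.aten.convolution_overrideable(
        x, w, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
except RuntimeError as e:
    print("no CPU implementation:", e)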
ramiro050 authored and gpetters94 committed May 8, 2023
1 parent 32e5e68 commit 1810682
Showing 7 changed files with 3 additions and 125 deletions.
65 changes: 0 additions & 65 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4343,37 +4343,6 @@ def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
   }];
 }
 
-def Torch_AtenConvolutionOverrideableOp : Torch_Op<"aten.convolution_overrideable", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$input,
-    AnyTorchTensorType:$weight,
-    AnyTorchOptionalTensorType:$bias,
-    AnyTorchListOfTorchIntType:$stride,
-    AnyTorchListOfTorchIntType:$padding,
-    AnyTorchListOfTorchIntType:$dilation,
-    Torch_BoolType:$transposed,
-    AnyTorchListOfTorchIntType:$output_padding,
-    Torch_IntType:$groups
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenConvolutionOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 9, 1);
-    }
-    void AtenConvolutionOverrideableOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 9, 1);
-    }
-  }];
-}
-
 def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -4503,40 +4472,6 @@ def Torch_AtenConvolutionBackwardOp : Torch_Op<"aten.convolution_backward", [
   }];
 }
 
-def Torch_AtenConvolutionBackwardOverrideableOp : Torch_Op<"aten.convolution_backward_overrideable", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$grad_output,
-    AnyTorchTensorType:$input,
-    AnyTorchTensorType:$weight,
-    AnyTorchListOfTorchIntType:$stride,
-    AnyTorchListOfTorchIntType:$padding,
-    AnyTorchListOfTorchIntType:$dilation,
-    Torch_BoolType:$transposed,
-    AnyTorchListOfTorchIntType:$output_padding,
-    Torch_IntType:$groups,
-    AnyTorchListOfTorchBoolType:$output_mask
-  );
-  let results = (outs
-    AnyTorchTensorType:$grad_input,
-    AnyTorchTensorType:$grad_weight,
-    AnyTorchTensorType:$grad_bias
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenConvolutionBackwardOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 10, 3);
-    }
-    void AtenConvolutionBackwardOverrideableOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 10, 3);
-    }
-  }];
-}
-
 def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
     AllowsTypeRefinement,
     HasValueSemantics,
6 changes: 0 additions & 6 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7061,12 +7061,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.convolution_backward_overrideable\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.list<bool>) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
" %1 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.batch_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
43 changes: 0 additions & 43 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1452,24 +1452,6 @@ class DecomposeAtenMaskedFillScalarOp
   }
 };
 
-} // namespace
-// Decompose aten.convolution_overrideable to aten.convolution op.
-namespace {
-class DecomposeAtenConvolutionOverrideableOp
-    : public OpRewritePattern<AtenConvolutionOverrideableOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenConvolutionOverrideableOp op,
-                                PatternRewriter &rewriter) const override {
-
-    rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
-        op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
-        op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
-        op.getOutputPadding(), op.getGroups());
-
-    return success();
-  }
-};
 } // namespace
 
 // Decompose aten._convolution-like to aten.convolution
@@ -1533,27 +1515,6 @@ class DecomposeAtenConvTranspose2dOp
 };
 } // namespace
 
-// Decompose aten.convolution_backward_overrideable to aten.convolution_backward
-// op.
-namespace {
-class DecomposeAtenConvolutionBackwardOverrideableOp
-    : public OpRewritePattern<AtenConvolutionBackwardOverrideableOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenConvolutionBackwardOverrideableOp op,
-                                PatternRewriter &rewriter) const override {
-
-    Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
-    rewriter.replaceOpWithNewOp<AtenConvolutionBackwardOp>(
-        op, op.getResultTypes(), op.getGradOutput(), op.getInput(), op.getWeight(),
-        none, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
-        op.getOutputPadding(), op.getGroups(), op.getOutputMask());
-
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class DecomposeAtenConvolutionBackwardOp
     : public OpRewritePattern<AtenConvolutionBackwardOp> {
@@ -3926,8 +3887,6 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
-    addPatternIfTargetOpIsIllegal<
-        DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
@@ -3949,8 +3908,6 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionOverrideableOp>(
-        patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
     addPatternIfTargetOpIsIllegal<
2 changes: 0 additions & 2 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -379,7 +379,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenWhereScalarOtherOp>();
   target.addIllegalOp<AtenWhereScalarSelfOp>();
   target.addIllegalOp<AtenMaskedFillScalarOp>();
-  target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
   target.addIllegalOp<AtenSizeOp>();
   target.addIllegalOp<AtenReshapeOp>();
   target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
@@ -405,7 +404,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenLayerNormOp>();
   target.addIllegalOp<AtenNativeLayerNormOp>();
   target.addIllegalOp<AtenNativeBatchNormOp>();
-  target.addIllegalOp<AtenConvolutionOverrideableOp>();
   target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
   target.addIllegalOp<AtenConvolutionBackwardOp>();
   target.addIllegalOp<AtenConv2dOp>();
7 changes: 3 additions & 4 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -714,8 +714,8 @@ void TypeAnalysis::visitOperation(Operation *op,

   // Promote the two dtypes assuming non-zero rank.
   if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
-          Aten_ConvolutionOp, AtenMvOp, AtenConvolutionOverrideableOp,
-          AtenConvTranspose2dInputOp, AtenMseLossOp>(op)) {
+          Aten_ConvolutionOp, AtenMvOp, AtenConvTranspose2dInputOp,
+          AtenMseLossOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
@@ -845,8 +845,7 @@ void TypeAnalysis::visitOperation(Operation *op,

   // 3 results take dtype from first operand.
   if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp,
-          AtenConvolutionBackwardOp, AtenConvolutionBackwardOverrideableOp>(
-          op)) {
+          AtenConvolutionBackwardOp>(op)) {
     auto self = operands[0]->getValue();
     auto result0Knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
3 changes: 0 additions & 3 deletions python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -840,9 +840,6 @@ def aten〇flip〡shape(self: List[int], dims: List[int]) -> List[int]:
 def aten〇convolution_backward〡shape(grad_output: List[int], input: List[int], weight: List[int], bias_sizes: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
     return upstream_shape_functions.conv_backwards(grad_output, input, weight, bias_sizes)
 
-def aten〇convolution_backward_overrideable〡shape(grad_output: List[int], input: List[int], weight: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
-    return upstream_shape_functions.conv_backwards(grad_output, input, weight, None)
-
 def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
     return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
 
2 changes: 0 additions & 2 deletions python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -359,12 +359,10 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"),
emit("aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
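The two `emit(...)` removals above are what produce the `GeneratedTorchOps.td` deletions at the top of this diff: torch_ods_gen turns each registered JIT operator signature into a TableGen op definition. A rough sketch of the naming correspondence (simplified; the `op_def_name` helper below is illustrative, not the actual generator code):

def op_def_name(jit_signature: str) -> str:
    # Map e.g. "aten::convolution_overrideable : (...) -> (Tensor)" to
    # "Torch_AtenConvolutionOverrideableOp", mirroring the names of the
    # generated defs removed from GeneratedTorchOps.td in this commit.
    qualified = jit_signature.split(" : ")[0]  # "aten::convolution_overrideable"
    ns, name = qualified.split("::")
    camel = "".join(part.capitalize() for part in name.split("_"))
    return f"Torch_{ns.capitalize()}{camel}Op"

print(op_def_name("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)"))
# -> Torch_AtenConvolutionOverrideableOp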
