Skip to content

Commit

Permalink
[MLIR][TORCH] Add E2E support for aten.upsample_nearest2d_backward.ve…
Browse files Browse the repository at this point in the history
…c op

Signed-Off By: Vivek Khandelwal<[email protected]>
  • Loading branch information
vivekkhandelwal1 committed Nov 4, 2022
1 parent db5a496 commit fedf8c0
Show file tree
Hide file tree
Showing 10 changed files with 284 additions and 3 deletions.
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,4 +621,6 @@
"Fill_TensorFloat32WithFloat32_basic",
"Fill_TensorFloat32WithFloat64_basic",
"Fill_TensorFloat32WithInt64_basic",
"UpSampleNearest2dBackwardVec_basic",
"UpSampleNearest2dBackwardOutputSizeNone_basic",
}
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ SmallVector<Value>
castIntVectorToIndexVector(OpBuilder &b, Location loc,
SmallVectorImpl<Value> &intValues);

SmallVector<Value>
castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
SmallVectorImpl<Value> &indexValues);

Value getDimOp(OpBuilder &b, Location loc, Value v, int dim);

SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
Expand Down
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4893,6 +4893,32 @@ def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [
}];
}

def Torch_AtenUpsampleNearest2dBackwardVecOp : Torch_Op<"aten.upsample_nearest2d_backward.vec", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchOptionalListOfTorchIntType:$output_size,
AnyTorchListOfTorchIntType:$input_size,
AnyTorchOptionalListOfTorchFloatType:$scale_factors
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUpsampleNearest2dBackwardVecOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenUpsampleNearest2dBackwardVecOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
186 changes: 186 additions & 0 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,190 @@ class ConvertAtenUpsampleNearest2dVecOp
};
} // namespace

static Value getGradOutputValue(OpBuilder &builder, Location loc,
Value gradOutput, Type gradOutputElemType,
Value numBatch, Value numChannel,
Value inputIndexH, Value inputIndexW,
Value kernelIndexH, Value kernelIndexW,
SmallVector<Value> &gradOutputSizeIndexValues,
SmallVector<Value, 2> &scaleFactorsIntValues) {
Value constantOne = builder.create<arith::ConstantIndexOp>(loc, 1);

Value outputIndexH = builder.create<arith::MulIOp>(
loc, inputIndexH, castIntToIndex(builder, loc, scaleFactorsIntValues[0]));
outputIndexH = builder.create<arith::AddIOp>(loc, outputIndexH, kernelIndexH);

Value outputIndexW = builder.create<arith::MulIOp>(
loc, inputIndexW, castIntToIndex(builder, loc, scaleFactorsIntValues[1]));
outputIndexW = builder.create<arith::AddIOp>(loc, outputIndexW, kernelIndexW);

// Handling corner cases.
Value gradOutputHMinusOne = builder.create<arith::SubIOp>(
loc, gradOutputSizeIndexValues[2], constantOne);
Value predH = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, outputIndexH, gradOutputHMinusOne);
outputIndexH = builder.create<arith::SelectOp>(loc, predH, outputIndexH,
gradOutputHMinusOne);

Value gradOutputWMinusOne = builder.create<arith::SubIOp>(
loc, gradOutputSizeIndexValues[3], constantOne);
Value predW = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, outputIndexW, gradOutputWMinusOne);
outputIndexW = builder.create<arith::SelectOp>(loc, predW, outputIndexW,
gradOutputWMinusOne);

Value gradOutputValue = builder.create<tensor::ExtractOp>(
loc, gradOutput,
ValueRange{numBatch, numChannel, outputIndexH, outputIndexW});
Value constantZero =
builder.create<arith::ConstantOp>(loc, builder.getF32FloatAttr(0.0));
Value pred = builder.create<arith::AndIOp>(loc, predH, predW);
Value result = builder.create<arith::SelectOp>(
loc, pred, gradOutputValue,
convertScalarToDtype(builder, loc, constantZero, gradOutputElemType));

return result;
}

// The implementation of the `aten.upsample_nearest2d_backward.vec` op's
// lowering is as follows:
// gradOutput: Tensor of size [n, c, oh, ow]
// outTensor: Tensor of size [n, c, ih, iw], initialized with zero
// kh = ceil(oh/ih), kw = ceil(ow/iw)
//
// for i in range(n):
// for j in range(c):
// for p in range(ih):
// for q in range(iw):
// for x in range(kh):
// for y in range(kw):
// outTensor[i, j, p, q] += gradOutput[i, j, (p*kh)+x, (q*kw)+y]
namespace {
class ConvertAtenUpsampleNearest2dBackwardVecOp
: public OpConversionPattern<AtenUpsampleNearest2dBackwardVecOp> {

public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenUpsampleNearest2dBackwardVecOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Location loc = op->getLoc();
Value gradOutput = adaptor.grad_output();

Type resultType = getTypeConverter()->convertType(op.getResult().getType());
auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
auto gradOutputRank = gradOutputType.getRank();
Type elementType = gradOutputType.getElementType();

SmallVector<Value> gradOutputSizeIndexValues =
getTensorSizes(rewriter, loc, gradOutput);
SmallVector<Value> gradOutputSizeIntValues =
castIndexVectorToInt64Vector(rewriter, loc, gradOutputSizeIndexValues);
SmallVector<Value, 2> scaleFactorsFloatValues;

SmallVector<Value, 4> inputSizeTorchInt;
if (!getListConstructElements(op.input_size(), inputSizeTorchInt))
return rewriter.notifyMatchFailure(
op, "unimplemented: the input_size is not constructed from "
"ListConstruct");
SmallVector<Value, 4> inputSizeIntValues;
inputSizeIntValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), inputSizeTorchInt);

// The dimension at which the scaling starts.
unsigned hDimOffset = 2;

if (!op.scale_factors().getType().isa<Torch::NoneType>()) {
SmallVector<Value, 2> scaleFactorsTorchFloat;
if (!getListConstructElements(op.scale_factors(), scaleFactorsTorchFloat))
return rewriter.notifyMatchFailure(
op, "unimplemented: the scale_factors is not constructed from "
"ListConstruct");
scaleFactorsFloatValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
} else {
for (unsigned i = hDimOffset; i < gradOutputRank; i++) {
auto scaleFactorVal = rewriter.create<arith::DivFOp>(
loc,
convertScalarToDtype(rewriter, loc, gradOutputSizeIntValues[i],
mlir::Float32Type::get(op->getContext())),
convertScalarToDtype(rewriter, loc, inputSizeIntValues[i],
mlir::Float32Type::get(op->getContext())));
scaleFactorsFloatValues.push_back(scaleFactorVal);
}
}

SmallVector<Value, 2> scaleFactorsIntValues;
for (auto v : scaleFactorsFloatValues)
scaleFactorsIntValues.push_back(convertScalarToDtype(
rewriter, loc, rewriter.create<math::CeilOp>(loc, v),
mlir::IntegerType::get(op->getContext(), 64)));

Value outTensor = createZeroInitTensor(
rewriter, loc,
castIntVectorToIndexVector(rewriter, loc, inputSizeIntValues),
elementType);

Value kernelTensor = rewriter.create<tensor::EmptyOp>(
loc,
getAsOpFoldResult(
castIntVectorToIndexVector(rewriter, loc, scaleFactorsIntValues)),
elementType);
unsigned kernelRank = scaleFactorsIntValues.size();

SmallVector<AffineExpr> affineExprs;
for (unsigned i = 0; i < gradOutputRank; i++)
affineExprs.push_back(rewriter.getAffineDimExpr(i));

AffineMap outputMap =
AffineMap::get(gradOutputRank + kernelRank,
/*symbolCount=*/0, affineExprs, op->getContext());

affineExprs.clear();
for (unsigned i = gradOutputRank; i < gradOutputRank + kernelRank; i++)
affineExprs.push_back(rewriter.getAffineDimExpr(i));

AffineMap kernelMap =
AffineMap::get(gradOutputRank + kernelRank,
/*symbolCount=*/0, affineExprs, op->getContext());

SmallVector<AffineMap> indexingMaps{kernelMap, outputMap};
SmallVector<StringRef> iteratorTypes(gradOutputRank,
getParallelIteratorTypeName());
iteratorTypes.push_back(getReductionIteratorTypeName());
iteratorTypes.push_back(getReductionIteratorTypeName());

Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, outTensor.getType(), ValueRange{kernelTensor},
ValueRange{outTensor},
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value n = rewriter.create<linalg::IndexOp>(loc, 0);
Value c = rewriter.create<linalg::IndexOp>(loc, 1);
Value ih = rewriter.create<linalg::IndexOp>(loc, 2);
Value iw = rewriter.create<linalg::IndexOp>(loc, 3);
Value kh = rewriter.create<linalg::IndexOp>(loc, 4);
Value kw = rewriter.create<linalg::IndexOp>(loc, 5);
Value accValue = getGradOutputValue(
rewriter, loc, gradOutput, elementType, n, c, ih, iw, kh,
kw, gradOutputSizeIndexValues, scaleFactorsIntValues);
Value outputVal = args[1];
outputVal =
rewriter.create<arith::AddFOp>(loc, outputVal, accValue);
b.create<linalg::YieldOp>(loc, outputVal);
})
->getResult(0);

rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
return success();
}
};
} // namespace

void mlir::torch::torch_to_linalg::
populateIndirectDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
Expand All @@ -913,4 +1097,6 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
target.addIllegalOp<AtenUpsampleNearest2dBackwardVecOp>();
patterns.add<ConvertAtenUpsampleNearest2dBackwardVecOp>(typeConverter, context);
}
9 changes: 9 additions & 0 deletions lib/Conversion/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ castIntVectorToIndexVector(OpBuilder &b, Location loc,
return indexValues;
}

SmallVector<Value>
castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
SmallVectorImpl<Value> &indexValues) {
SmallVector<Value> intValues;
for (Value v : indexValues)
intValues.push_back(castIndexToInt64(b, loc, v));
return intValues;
}

Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
return b.createOrFold<tensor::DimOp>(loc, v, dim);
}
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,8 +700,8 @@ void TypeAnalysis::visitOperation(Operation *op,
AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp,
AtenFillTensorOp>(op)) {
AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
AtenUpsampleNearest2dBackwardVecOp>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
}

Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6076,6 +6076,9 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<float>>) -> !torch.list<int> {\n"
" return %arg2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,9 @@ def aten〇max_pool2d_with_indices(self: List[int], kernel_size: List[int], stri
def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
return self

def aten〇upsample_nearest2d_backward〇vec(grad_output: List[int], output_size: Optional[List[int]], input_size: List[int], scale_factors: Optional[List[float]]) -> List[int]:
return input_size

# TODO: This should be upstreamed.
# See https://github.com/pytorch/pytorch/pull/76889 for an example.
def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
emit("aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)")

# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
Expand Down
49 changes: 48 additions & 1 deletion python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3021,4 +3021,51 @@ def forward(self, x):

@register_test_case(module_factory=lambda: SingleTensorTupleReturn())
def SingleTensorTupleReturn_basic(module, tu: TestUtils):
module.forward(torch.randn(2, 4))
module.forward(torch.randn(2, 4))


# ==============================================================================


class UpSampleNearest2dBackwardVec(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
])
def forward(self, input):
return torch.ops.aten.upsample_nearest2d_backward(input,
output_size=[4, 8],
input_size=[1, 1, 2, 3],
scale_factors=None)


@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardVec())
def UpSampleNearest2dBackwardVec_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 4, 8))


class UpSampleNearest2dBackwardOutputSizeNone(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float64, True),
])
def forward(self, input):
return torch.ops.aten.upsample_nearest2d_backward(input,
output_size=None,
input_size=[1, 1, 2, 3],
scale_factors=[3.0, 4.0])


@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardOutputSizeNone())
def UpSampleNearest2dBackwardOutputSizeNone_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))

0 comments on commit fedf8c0

Please sign in to comment.