Commit e16f13e

update

newling committed Nov 6, 2023
1 parent 4b9db99 commit e16f13e
Showing 7 changed files with 241 additions and 9 deletions.
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6419,6 +6419,30 @@ def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
}];
}

def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::pixel_shuffle : (Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$upscale_factor
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPixelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenPixelShuffleOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
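
For reference, a quick way to surface this generated op from Python. A minimal editorial sketch, assuming the pt1-era `torch_mlir.compile` entry point with output_type="torch" (adjust to your torch-mlir version):

import torch
import torch_mlir

# Hypothetical toy module; torch.ops.aten.pixel_shuffle is the op being added.
class PixelShuffle(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.pixel_shuffle(x, 3)

# Assumption: the pt1-era compile API. The printed Torch-dialect IR should
# contain a torch.aten.pixel_shuffle op matching the definition above.
module = torch_mlir.compile(PixelShuffle(), torch.rand(3, 18, 2, 2),
                            output_type="torch")
print(module)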

def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [
AllowsTypeRefinement,
ReadOnly
10 changes: 10 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -217,6 +217,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
xIndices.assign(llvm::to_vector(llvm::seq<int64_t>(0, xDims.size())));
return success();
}

return failure();
}

@@ -350,6 +351,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
return sizeInt == -1;
return false;
}) > 1) {

llvm::errs() << "In lowering to linalg. More than one element in size "
"list is -1\n";
return rewriter.notifyMatchFailure(
op, "at most one element in size list is allowed to be -1");
}
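
For context, the rule this guard enforces, checked in plain PyTorch (an editorial sketch, not part of the diff): at most one entry of a view/reshape size list may be -1, since a single unknown dimension can be inferred from the element count but two cannot.

import torch

t = torch.arange(12)
print(t.view(3, -1).shape)  # torch.Size([3, 4]); the lone -1 is inferred
# t.view(-1, -1)            # RuntimeError: only one dimension can be inferred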
Expand Down Expand Up @@ -454,6 +458,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
if (assumedDynamicDimNotSplit && inputShapeSlice.size() == 1 &&
outputShapeSlice.size() != 1 &&
inputShapeSlice[0] == kUnknownSize) {

llvm::errs() << "An ambiguous expand from dynamic" << '\n';

return rewriter.notifyMatchFailure(
op, "found ambiguous expand of dynamic input sizes "
"(e.g. [-1, -1] -> [-1, -1, -1])");
Expand Down Expand Up @@ -491,6 +498,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
outputSliceIndices.push_back(0);
assumedDynamicDimNotSplit = true;
} else {

// This case is being hit.
llvm::errs() << "Unhandled case of expand/collapse" << '\n';
return rewriter.notifyMatchFailure(
op, "unimplemented: found unhandled case of expansion/collapse "
"in `aten.view`");
181 changes: 173 additions & 8 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -93,7 +93,7 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
keepDimCst, dtype);
}

// Redunction function to calculate max along given `dim`.
// Reduction function to calculate max along given `dim`.
static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
Operation *op, Value input, Value dim,
bool keepDim) {
@@ -211,6 +211,7 @@ class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
Location loc = op.getLoc();
SmallVector<int64_t, 4> dims;
if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))

return rewriter.notifyMatchFailure(op,
"non-const dim parameter unsupported");

@@ -227,8 +228,7 @@ class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
}
// For every dimension included in `dim` of the op, iterated over in
// reverse order, we create a call to aten.max.dim.
std::sort(dims.begin(), dims.end());
std::reverse(dims.begin(), dims.end());
std::sort(dims.rbegin(), dims.rend());
for (int64_t dimInt : dims) {
int64_t inputRank = inputTy.getSizes().size();
dimInt = toPositiveDim(dimInt, inputRank);
@@ -255,6 +255,7 @@ class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
Location loc = op.getLoc();
Value self = op.getSelf();
MLIRContext *context = op.getContext();

std::optional<unsigned> maybeRank = getTensorRank(self);
if (!maybeRank)
return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
@@ -386,9 +387,10 @@ class DecomposeAtenGluOp : public OpRewritePattern<AtenGluOp> {

Value remainder = rewriter.create<AtenRemainderIntOp>(loc, dimSize, two);
Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, zero);

rewriter.create<RuntimeAssertOp>(
loc, eqOrNot,
rewriter.getStringAttr("AtenGluOp's dim size must be multiply of 2"));
rewriter.getStringAttr("AtenGluOp's dim size must be multiple of 2"));

Value splitLength = rewriter.create<AtenFloordivIntOp>(loc, dimSize, two);
Value a = rewriter.create<AtenNarrowOp>(loc, outputTy, self, dim, zero,
@@ -443,6 +445,7 @@ class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
int64_t n;

if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
return rewriter.notifyMatchFailure(op,
"unimplemented: n must be constant");
@@ -1092,9 +1095,171 @@ class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
};
} // namespace

// Decompose aten.pixel_shuffle into aten.permute and aten.reshape operations:
//
// If input is a tensor of shape (*leading_dims, C*r*r, H, W), where
// leading_dims is of size N, then
// X = pixel_shuffle(input, upscale_factor)
//
// gets replaced with
// A = input.reshape(*leading_dims, C, r, r, H, W)
// B = A.permute(0, ..., N, N+3, N+1, N+4, N+2)
// X = B.reshape(*leading_dims, C, r*H, r*W)
//
// 'r' above is referred to as the 'upscale factor' or just 'factor' below.
namespace {
class DecomposeAtenPixelShuffleOp
: public OpRewritePattern<AtenPixelShuffleOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenPixelShuffleOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Value inValue = op.getSelf();
auto inType = inValue.getType().cast<BaseTensorType>();
auto maybeSizes = inType.getOptionalSizes();
if (!maybeSizes) {
return rewriter.notifyMatchFailure(
op, "Expected input tensor to have known rank.");
}
auto inShape = maybeSizes.value();
auto inRank = inShape.size();

// At least 3 dimensions are needed
// (case when leading_dims is empty).
if (inRank < 3)
return rewriter.notifyMatchFailure(
op, "Expected input tensor to have rank greater than 2.");

auto nLeadingDims = inRank - 3;

// Get the size of dimension 'i'. Note the use of 'createOrFold' instead
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
// folded to a ConstantOp.
auto getDimSize = [&](uint64_t i) -> Value {
Value dim =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
};

auto inC = getDimSize(inRank - 3);
auto inH = getDimSize(inRank - 2);
auto inW = getDimSize(inRank - 1);

auto factor = op.getUpscaleFactor();

Value factorSquared =
rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
Value outC =
rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, factorSquared);

Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, factor);
Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, factor);

// Shape of 'A' in the comment at the top
SmallVector<Value> prePermuteShape;
prePermuteShape.reserve(nLeadingDims + 5);

// Shape of 'B' in the comment at the top.
SmallVector<Value> postPermuteShape;
postPermuteShape.reserve(nLeadingDims + 5);

SmallVector<Value> outShape;
outShape.reserve(nLeadingDims + 3);

SmallVector<Value> permutation;
permutation.reserve(nLeadingDims + 5);

for (unsigned i = 0; i < nLeadingDims; ++i) {
auto dimensionAttr = rewriter.getI64IntegerAttr(i);
Value dimensionValue = rewriter.create<ConstantIntOp>(loc, dimensionAttr);
Value leadingDimSize =
rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dimensionValue);
prePermuteShape.push_back(leadingDimSize);
postPermuteShape.push_back(leadingDimSize);
outShape.push_back(leadingDimSize);
permutation.push_back(dimensionValue);
}

const auto inOptionalDType = inType.getOptionalDtype();

auto getTypeFromShape = [inOptionalDType](auto &&vals) {
// Get a vector of integers from a vector of Values.
auto getIntShape = [](auto &&vals) {
SmallVector<int64_t> shape;
shape.reserve(vals.size());
for (auto v : vals) {
int64_t cst_val;
if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
shape.push_back(cst_val);
} else {
shape.push_back(kUnknownSize);
}
}
return shape;
};

const auto intShape = getIntShape(vals);
return ValueTensorType::get(vals[0].getContext(),
llvm::ArrayRef(intShape), inOptionalDType);
};

prePermuteShape.insert(prePermuteShape.end(),
{outC, factor, factor, inH, inW});

postPermuteShape.insert(postPermuteShape.end(),
{outC, inH, factor, inW, factor});

outShape.insert(outShape.end(), {outC, outH, outW});

SmallVector<uint64_t> permutationTail{0, 3, 1, 4, 2};
for (uint64_t d : permutationTail) {
permutation.push_back(rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(nLeadingDims + d)));
}

auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));

Value shapeA =
rewriter.create<PrimListConstructOp>(loc, listType, prePermuteShape);

Value A = rewriter.create<AtenReshapeOp>(
loc, getTypeFromShape(prePermuteShape), inValue, shapeA);

// llvm::errs() << "new reshape (A) op " << A << " which has shape " << shapeA
// << "\n\n";

Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
permutation);

Value B = rewriter.create<AtenPermuteOp>(
loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder);

Value outShapeList =
rewriter.create<PrimListConstructOp>(loc, listType, outShape);

// TODO(jn) figure out why the type of the returned value must be made
// undefined. Is this because I need to implement type inference somewhere
// else? If I don't do this, I get an error about the function return type
// disagreeing.
auto finalShape = getTypeFromShape(outShape);
auto finalType = finalShape.getWithSizesAndDtype({}, {});

Value out =
rewriter.createOrFold<AtenReshapeOp>(loc, finalType, B, outShapeList);

rewriter.replaceAllUsesWith(op.getResult(), out);
rewriter.eraseOp(op);

return success();
}
};
} // namespace
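
The decomposition documented above is easy to sanity-check in plain PyTorch. A minimal editorial sketch (not part of the commit), generalized over leading dims:

import torch

def pixel_shuffle_decomposed(x, r):
    # x has shape (*leading_dims, C*r*r, H, W).
    lead, (c_rr, h, w) = x.shape[:-3], x.shape[-3:]
    c = c_rr // (r * r)
    n = len(lead)
    a = x.reshape(*lead, c, r, r, h, w)                     # 'A' in the comment
    perm = list(range(n)) + [n, n + 3, n + 1, n + 4, n + 2]
    b = a.permute(*perm)                                    # 'B' in the comment
    return b.reshape(*lead, c, r * h, r * w)                # 'X' in the comment

x = torch.rand(3, 18, 2, 2)
assert torch.equal(pixel_shuffle_decomposed(x, 3), torch.pixel_shuffle(x, 3))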

// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
Value input) {
static Value
getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) {
BaseTensorType inputType = input.getType().cast<BaseTensorType>();

Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
@@ -4717,8 +4882,7 @@ class DecomposePrimsSqueezeOp : public OpRewritePattern<PrimsSqueezeOp> {
return rewriter.notifyMatchFailure(
op, "all dimensions must be constant ints");

std::sort(dimensions.begin(), dimensions.end());
std::reverse(dimensions.begin(), dimensions.end());
std::sort(dimensions.rbegin(), dimensions.rend());

if (dimensions.size() == 0) {
rewriter.replaceOp(op, input);
@@ -5463,6 +5627,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -391,6 +391,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNormScalarOptDimOp>();
target.addIllegalOp<AtenSelectIntOp>();
target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenPixelShuffleOp>();
target.addIllegalOp<AtenTOp>();
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Utils/Utils.cpp
@@ -206,7 +206,7 @@ bool Torch::isViewLikeOp(Operation *op) {
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
AtenViewAsComplexOp, AtenViewAsRealOp>(op);
AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
@@ -481,6 +481,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
emit("aten::permute : (Tensor, int[]) -> (Tensor)")
emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
31 changes: 31 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -595,6 +595,37 @@ def PermuteModule_basic(module, tu: TestUtils):
# ==============================================================================


class PixelShuffleModuleStaticShape(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.pixel_shuffle(x, 3)

@register_test_case(module_factory=lambda: PixelShuffleModuleStaticShape())
def PixelShuffleModuleStaticShape_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 18, 2, 2))


class PixelShuffleModuleDynShape(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1, 1, 1], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.pixel_shuffle(x, 2)

@register_test_case(module_factory=lambda: PixelShuffleModuleDynShape())
def PixelShuffleModuleDynShape_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 1, 1))
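
For reference, the output shapes these two tests exercise, checked in plain PyTorch (an editorial sketch, independent of the torch-mlir lowering):

import torch

# Static test: (3, 18, 2, 2) with r=3 -> C = 18/(3*3) = 2, giving (3, 2, 6, 6).
print(torch.pixel_shuffle(torch.rand(3, 18, 2, 2), 3).shape)
# Dynamic test: (4, 1, 1) with r=2 -> C = 4/(2*2) = 1, giving (1, 2, 2).
print(torch.pixel_shuffle(torch.rand(4, 1, 1), 2).shape)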


# ==============================================================================


class PermuteNegativeIndexModule(torch.nn.Module):

def __init__(self):