update, new test
newling committed Nov 5, 2023
1 parent d81a324 commit 805d502
Showing 4 changed files with 126 additions and 77 deletions.
10 changes: 10 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -217,6 +217,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
xIndices.assign(llvm::to_vector(llvm::seq<int64_t>(0, xDims.size())));
return success();
}

return failure();
}

@@ -350,6 +351,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
return sizeInt == -1;
return false;
}) > 1) {

llvm::errs() << "In lowering to linalg. More than one element in size "
"list is -1\n";
return rewriter.notifyMatchFailure(
op, "at most one element in size list is allowed to be -1");
}
@@ -441,6 +445,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
if (assumedDynamicDimNotSplit && inputShapeSlice.size() == 1 &&
outputShapeSlice.size() != 1 &&
inputShapeSlice[0] == kUnknownSize) {

llvm::errs() << "An ambiguous expand from dynamic" << '\n';

return rewriter.notifyMatchFailure(
op, "found ambiguous expand of dynamic input sizes "
"(e.g. [-1, -1] -> [-1, -1, -1])");
@@ -478,6 +485,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
outputSliceIndices.push_back(0);
assumedDynamicDimNotSplit = true;
} else {

// This case is being hit.
llvm::errs() << "Unhandled case of expand/collapse" << '\n';
return rewriter.notifyMatchFailure(
op, "unimplemented: found unhandled case of expansion/collapse "
"in `aten.view`");
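For context on the failure paths that the aten.view lowering reports above: a view from fully dynamic input sizes to more (also dynamic) output sizes has no unique expand/collapse mapping. A minimal PyTorch sketch of the "[-1, -1] -> [-1, -1, -1]" style of ambiguity, illustrative only and not part of this commit (the function and sizes are made up):

    import torch

    # Both calls below are valid at runtime, so with fully dynamic shapes a
    # compile-time lowering cannot decide which input dimension is being split.
    def ambiguous_view(x: torch.Tensor, a: int, b: int, c: int) -> torch.Tensor:
        return x.view(a, b, c)

    x = torch.arange(24).reshape(4, 6)
    print(ambiguous_view(x, 2, 2, 6).shape)  # torch.Size([2, 2, 6]), splits dim 0
    print(ambiguous_view(x, 4, 3, 2).shape)  # torch.Size([4, 3, 2]), splits dim 1
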
167 changes: 97 additions & 70 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -20,7 +20,6 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>

using namespace mlir;
@@ -230,7 +229,6 @@ class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
// For every dimension included in `dim` of the op, iterated over in
// reverse order, we create a call to aten.max.dim.
std::sort(dims.rbegin(), dims.rend());
// std::reverse(dims.begin(), dims.end());
for (int64_t dimInt : dims) {
int64_t inputRank = inputTy.getSizes().size();
dimInt = toPositiveDim(dimInt, inputRank);
@@ -390,7 +388,6 @@ class DecomposeAtenGluOp : public OpRewritePattern<AtenGluOp> {
Value remainder = rewriter.create<AtenRemainderIntOp>(loc, dimSize, two);
Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, zero);

// (jn) you can insert a runtime assert op?
rewriter.create<RuntimeAssertOp>(
loc, eqOrNot,
rewriter.getStringAttr("AtenGluOp's dim size must be multiple of 2"));
@@ -449,7 +446,6 @@ class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
Location loc = op.getLoc();
int64_t n;

// TODO(jn) use something like this?
if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
return rewriter.notifyMatchFailure(op,
"unimplemented: n must be constant");
@@ -1079,7 +1075,18 @@ class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
};
} // namespace

// Decompose aten.mv into: aten.matmul.
// Decompose aten.pixel_shuffle into: aten.permute and aten.reshape operations:
//
// If input is a tensor of shape (*leading_dims, C*r*r, H, W), where
// leading_dims is of size N, then
// X = pixel_shuffle(input, upscale_factor)
//
// gets replaced with
// A = input.reshape(*leading_dims, C, r, r, H, W)
// B = A.permute(0, ..., N, N+3, N+1, N+4, N+2)
// X = B.reshape(*leading_dims, C, r*H, r*W)
//
// 'r' above is referred to as the 'upscale factor' or just 'factor' below.
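A minimal PyTorch sketch of the equivalence described in the comment above, for a single leading dimension (N = 1); the concrete sizes are arbitrary and this is illustrative only, not part of the diff:

    import torch

    N0, C, r, H, W = 2, 3, 4, 5, 6
    x = torch.randn(N0, C * r * r, H, W)

    A = x.reshape(N0, C, r, r, H, W)
    B = A.permute(0, 1, 4, 2, 5, 3)   # (0, ..., N, N+3, N+1, N+4, N+2) with N = 1
    X = B.reshape(N0, C, r * H, r * W)

    assert torch.equal(X, torch.nn.functional.pixel_shuffle(x, r))
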
namespace {
class DecomposeAtenPixelShuffleOp
: public OpRewritePattern<AtenPixelShuffleOp> {
@@ -1091,124 +1098,145 @@ class DecomposeAtenPixelShuffleOp
Location loc = op.getLoc();
Value inValue = op.getSelf();
auto inType = inValue.getType().cast<BaseTensorType>();
auto maybeSizes = inType.getOptionalSizes();
if (!maybeSizes) {
return rewriter.notifyMatchFailure(
op, "Expected input tensor to have known rank.");
}
auto inShape = maybeSizes.value();
auto inRank = inShape.size();

auto sizes = inType.getSizes();
auto selfRank = sizes.size();
// At least 3 dimensions are needed
// (case when leading_dims is empty).
if (inRank < 3)
return rewriter.notifyMatchFailure(
op, "Expected input tensor to have rank greater than 2.");

// At least 3 dimensions are needed:
// (*, C*r*r, H, W) -> (*, C, rH, rW)
if (sizes.size() < 3)
return rewriter.notifyMatchFailure(op, "Unimplemented: rank < 3");
auto nLeadingDims = inRank - 3;

// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
// folded to a ConstantOp.
auto getDimSize = [&](uint64_t i) -> Value {
Value dim =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));

if (sizes[i] == kUnknownSize) {
return rewriter.create<AtenSizeIntOp>(loc, inValue, dim);
}
return rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(sizes[i]));
return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
};

auto inW = getDimSize(selfRank - 1);
auto inH = getDimSize(selfRank - 2);
auto inC = getDimSize(selfRank - 3);
auto inC = getDimSize(inRank - 3);
auto inH = getDimSize(inRank - 2);
auto inW = getDimSize(inRank - 1);

auto factor = op.getUpscaleFactor();

auto upscaleFactor = op.getUpscaleFactor();

Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, upscaleFactor);
Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, upscaleFactor);
auto upscaleFactorSquared =
rewriter.createOrFold<AtenMulIntOp>(loc, upscaleFactor, upscaleFactor);
Value factorSquared =
rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
Value outC =
rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, upscaleFactorSquared);
rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, factorSquared);

Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, factor);
Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, factor);

// Shape of 'A' in the comment at the top
SmallVector<Value> prePermuteShape;
prePermuteShape.reserve(nLeadingDims + 5);

// Shape of 'B' in the comment at the top.
SmallVector<Value> postPermuteShape;
postPermuteShape.reserve(nLeadingDims + 5);

SmallVector<Value> outShape;
outShape.reserve(nLeadingDims + 3);

SmallVector<Value> permutation;
permutation.reserve(nLeadingDims + 5);

// process the leading '*' dimensions.
for (unsigned i = 0; i < selfRank - 3; ++i) {
for (unsigned i = 0; i < nLeadingDims; ++i) {
auto dimensionAttr = rewriter.getI64IntegerAttr(i);
auto dimension = rewriter.create<ConstantIntOp>(loc, dimensionAttr);
Value leadingDimSize =
rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dimension);
Value dimensionValue = rewriter.create<ConstantIntOp>(loc, dimensionAttr);
Value leadingDimSize =
rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dimensionValue);
prePermuteShape.push_back(leadingDimSize);
postPermuteShape.push_back(leadingDimSize);
outShape.push_back(leadingDimSize);
permutation.push_back(dimension);
permutation.push_back(dimensionValue);
}

auto getIntShape = [](auto &&vals) {
SmallVector<int64_t> shape;
for (auto v : vals) {
int64_t cst_val;
if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
shape.push_back(cst_val);
} else {
shape.push_back(kUnknownSize);
const auto inOptionalDType = inType.getOptionalDtype();

auto getTypeFromShape = [inOptionalDType](auto &&vals) {
// Get a vector of integers from a vector of Values.
auto getIntShape = [](auto &&vals) {
SmallVector<int64_t> shape;
shape.reserve(vals.size());
for (auto v : vals) {
int64_t cst_val;
if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
shape.push_back(cst_val);
} else {
shape.push_back(kUnknownSize);
}
}
}
return shape;
};
return shape;
};

auto getTypeFromShape = [&getIntShape, &inType](auto &&vals) {
auto intShape = getIntShape(vals);
const auto intShape = getIntShape(vals);
return ValueTensorType::get(vals[0].getContext(),
llvm::ArrayRef(intShape),
inType.getOptionalDtype());
llvm::ArrayRef(intShape), inOptionalDType);
};

prePermuteShape.insert(prePermuteShape.end(),
{outC, upscaleFactor, upscaleFactor, inH, inW});
{outC, factor, factor, inH, inW});

postPermuteShape.insert(postPermuteShape.end(),
{outC, inH, upscaleFactor, inW, upscaleFactor});
{outC, inH, factor, inW, factor});

outShape.insert(outShape.end(), {outC, outH, outW});


// TODO(jn) : improved verification on permute when static shape
auto upRank = prePermuteShape.size();
for (uint64_t d :
{upRank - 5, upRank - 2, upRank - 4, upRank - 1, upRank - 3}) {
permutation.push_back(
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(d)));
SmallVector<uint64_t> permutationTail{0, 3, 1, 4, 2};
for (uint64_t d : permutationTail) {
permutation.push_back(rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(nLeadingDims + d)));
}

auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value shape0 =

Value shapeA =
rewriter.create<PrimListConstructOp>(loc, listType, prePermuteShape);

auto reshape0 = rewriter.create<AtenViewOp>(
loc, getTypeFromShape(prePermuteShape), inValue, shape0);
Value A = rewriter.create<AtenReshapeOp>(
loc, getTypeFromShape(prePermuteShape), inValue, shapeA);

// llvm::errs() << "new reshape (A) op " << A << " which has shape " << shapeA
// << "\n\n";

Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
permutation);

auto perm0 = rewriter.create<AtenPermuteOp>(
loc, getTypeFromShape(postPermuteShape), reshape0, permuteDimsOrder);
Value B = rewriter.create<AtenPermuteOp>(
loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder);

Value shape1 =
Value outShapeList =
rewriter.create<PrimListConstructOp>(loc, listType, outShape);

// TODO(jn) figure out why the type of the returned value must be made undefined. Is this because I need to implement type inference somewhere else? If I don't do this, I get error about function return type disagreeing.
auto finalShape = getTypeFromShape(outShape);
auto finalType = finalShape.getWithSizesAndDtype({}, {}); // finalShape.getOptionalDtype());
rewriter.replaceOpWithNewOp<AtenViewOp>(op, finalType,
perm0, shape1);
auto finalType = finalShape.getWithSizesAndDtype({}, {});

Value out =
rewriter.createOrFold<AtenReshapeOp>(loc, finalType, B, outShapeList);

rewriter.replaceAllUsesWith(op.getResult(), out);
rewriter.eraseOp(op);


return success();
}
};
} // namespace




// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value
getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) {
@@ -4834,8 +4862,7 @@ class DecomposePrimsSqueezeOp : public OpRewritePattern<PrimsSqueezeOp> {
return rewriter.notifyMatchFailure(
op, "all dimensions must be constant ints");

std::sort(dimensions.begin(), dimensions.end());
std::reverse(dimensions.begin(), dimensions.end());
std::sort(dimensions.rbegin(), dimensions.rend());

if (dimensions.size() == 0) {
rewriter.replaceOp(op, input);
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Utils/Utils.cpp
@@ -206,7 +206,7 @@ bool Torch::isViewLikeOp(Operation *op) {
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
AtenViewAsComplexOp, AtenViewAsRealOp>(op);
AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
24 changes: 18 additions & 6 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -595,20 +595,32 @@ def PermuteModule_basic(module, tu: TestUtils):
# ==============================================================================


class PixelShuffleModule(torch.nn.Module):

class PixelShuffleModuleStaticShape(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([1, 9, 2, 2], torch.float32, True)])
@annotate_args([None, ([3, 18, 2, 2], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.pixel_shuffle(x, 3)

@register_test_case(module_factory=lambda: PixelShuffleModuleStaticShape())
def PixelShuffleModuleStaticShape_basic(module, tu: TestUtils):
module.forward(tu.rand(3,18,2,2))


class PixelShuffleModuleDynShape(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1,1,1], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.pixel_shuffle(x, 2)

@register_test_case(module_factory=lambda: PixelShuffleModule())
def PixelShuffleModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1,9,2,2))
@register_test_case(module_factory=lambda: PixelShuffleModuleDynShape())
def PixelShuffleModuleDynShape_basic(module, tu: TestUtils):
module.forward(tu.rand(4,1,1))
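
For reference, what the dynamic-shape test above computes when run eagerly (a sketch, assuming the e2e harness compares the compiled module against eager PyTorch on the same tu.rand(4, 1, 1) input):

    import torch

    x = torch.rand(4, 1, 1)                 # C*r*r = 4 with no leading dims, H = W = 1
    y = torch.ops.aten.pixel_shuffle(x, 2)  # r = 2, so C = 1
    print(y.shape)                          # torch.Size([1, 2, 2])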


# ==============================================================================
