[tosa] Add TorchToTosa lowering for aten.arange.start_step op (#1442)
vivekkhandelwal1 authored Sep 30, 2022
1 parent aa31be7 commit 9dd5ae8
Showing 3 changed files with 77 additions and 0 deletions.
7 changes: 7 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -456,6 +456,13 @@
"BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic",
"SliceStaticModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeDtypeFloatModule_basic",
"ArangeIntModule_basic",
"ArangeNegativeStartIntModule_basic",
"ArangeStartIntModule_basic",
"ArangeStartNegativeStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
}
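(The surrounding entries suggest this hunk extends the TOSA pass set in xfail_sets.py, i.e., the arange e2e tests are now expected to pass through the TorchToTosa backend once this lowering lands.)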

LTC_XFAIL_SET = {
51 changes: 51 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3005,6 +3005,56 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
"unimplemented: broadcasts other than same rank or zero ranked tensor.");
}

template <>
LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
AtenArangeStartStepOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

TypeConverter *typeConverter = this->getTypeConverter();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();

// At this point all tensors should have value semantics, and hence the
// `layout` check can be ignored.

// TODO: Add support for the pin_memory feature.
// pin_memory should be either `False` or `None`.
bool pinMemory;
if (!op.pin_memory().getType().isa<Torch::NoneType>() &&
(!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) ||
pinMemory)) {
return rewriter.notifyMatchFailure(
op, "unimplemented: pin_memory must be either None or false");
}

int64_t start, step, end;
if (!matchPattern(op.start(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(
op, "unimplemented: value `start` should be a torch constant int");

if (!matchPattern(op.end(), m_TorchConstantInt(&end)))
return rewriter.notifyMatchFailure(
op, "unimplemented: value `end` should be a torch constant int");

if (!matchPattern(op.step(), m_TorchConstantInt(&step)))
return rewriter.notifyMatchFailure(
op, "unimplemented: value `step` should be a torch constant int");

// The result will always be a 1-d tensor.
// The size of the result is calculated as follows:
// ceil((end - start)/step)
int64_t resultShape = ceil((float)(end - start) / (float)step);
SmallVector<int64_t> values(resultShape, start);
for (int64_t i = 1; i < resultShape; i++)
values[i] += i * step;
Value result =
tosa::getConstTensor<int64_t>(rewriter, op, values, resultShape).value();

rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, result);
return success();
}

template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:
@@ -3653,6 +3703,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenMaxDimOp);
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
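The heart of the pattern is the size/value computation: arange(start, end, step) yields ceil((end - start) / step) elements, which the lowering materializes as a 1-D i64 constant and then casts to the converted result type. Below is a minimal standalone sketch of that arithmetic, outside the MLIR infrastructure; the arangeValues helper and the empty-range clamp are illustrative assumptions, not part of the committed pattern.

// Standalone sketch of the arithmetic performed by the lowering above.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int64_t> arangeValues(int64_t start, int64_t end, int64_t step) {
  assert(step != 0 && "step must be nonzero");
  // Matches the pattern's size computation: ceil((end - start) / step).
  int64_t size = static_cast<int64_t>(
      std::ceil(static_cast<double>(end - start) / static_cast<double>(step)));
  if (size < 0)
    size = 0; // empty result, e.g. end <= start with a positive step
  std::vector<int64_t> values(static_cast<size_t>(size), start);
  for (int64_t i = 1; i < size; ++i)
    values[i] += i * step; // element i is start + i * step
  return values;
}

int main() {
  // Mirrors the FileCheck test below: arange(0, 5, 1) -> [0, 1, 2, 3, 4].
  auto v = arangeValues(0, 5, 1);
  assert(v.size() == 5 && v.front() == 0 && v.back() == 4);
  return 0;
}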
19 changes: 19 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -828,3 +828,22 @@ func.func @torch.vtensor.literal_si64$basic() -> !torch.vtensor<[1,512],si64> {
%0 = torch.vtensor.literal(dense<-1> : tensor<1x512xsi64>) : !torch.vtensor<[1,512],si64>
return %0 : !torch.vtensor<[1,512],si64>
}

// -----
// CHECK-LABEL: func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CST5:.*]] = torch.constant.int 5
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[VAL_0:.*]] = "tosa.const"() {value = dense<[0, 1, 2, 3, 4]> : tensor<5xi64>} : () -> tensor<5xi64>
// CHECK: %[[VAL_1:.*]] = "tosa.cast"(%[[VAL_0]]) : (tensor<5xi64>) -> tensor<5xi64>
// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<5xi64> -> !torch.vtensor<[5],si64>
// CHECK: return %[[VAL_2]] : !torch.vtensor<[5],si64>
func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
%none = torch.constant.none
%int0 = torch.constant.int 0
%int5 = torch.constant.int 5
%int1 = torch.constant.int 1
%0 = torch.aten.arange.start_step %int0, %int5, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5],si64>
return %0 : !torch.vtensor<[5],si64>
}
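The CHECK lines above are verified with FileCheck against the output of the conversion pass; the RUN line at the top of basic.mlir (not shown in this diff) drives the test, typically an invocation along the lines of torch-mlir-opt -convert-torch-to-tosa. Note that the trailing tosa.cast is an i64-to-i64 no-op here, since the requested dtype already matches the element type of the generated constant.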
