[Torch] Emit and decompose prims.iota op (#3132)
penguin-wwy authored Apr 22, 2024
1 parent a60e84e commit e5bdd71
Showing 7 changed files with 94 additions and 0 deletions.
28 changes: 28 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -15909,6 +15909,34 @@ def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [
let hasFolder = 1;
}

def Torch_PrimsIotaOp : Torch_Op<"prims.iota", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prims::iota : (int, int, int, int, Device, bool) -> (Tensor)`";
let arguments = (ins
Torch_IntType:$length,
Torch_IntType:$start,
Torch_IntType:$step,
Torch_IntType:$dtype,
Torch_DeviceType:$device,
Torch_BoolType:$requires_grad
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult PrimsIotaOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void PrimsIotaOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}

def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
HasValueSemantics,
AllowsTypeRefinement,
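For orientation, the new op mirrors the PyTorch `prims.iota` call, which produces `length` evenly spaced integers starting at `start` with stride `step`. A minimal eager-mode sketch of the semantics (illustrative only, not part of the commit; assumes a PyTorch build that exposes `torch.ops.prims.iota`, as the e2e test below does):

import torch

# length=5, start=2, step=3 -> values start + i*step for i in range(length)
t = torch.ops.prims.iota(5, start=2, step=3, dtype=torch.int64,
                         device='cpu', requires_grad=False)
print(t)  # tensor([ 2,  5,  8, 11, 14])
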
7 changes: 7 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8653,6 +8653,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prims.iota\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.Device, %arg5: !torch.bool) -> !torch.int {\n"
" return %arg3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.float) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4789,6 +4789,35 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern<AtenArangeStartOp> {
};
} // namespace

namespace {
// The `prims.iota` op is decomposed into the `aten.arange.start_step` op.
class DecomposePrimsIotaOp : public OpRewritePattern<PrimsIotaOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimsIotaOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
int64_t length, start, step;
if (!matchPattern(op.getLength(), m_TorchConstantInt(&length)))
return rewriter.notifyMatchFailure(
op, "unimplemented: low must be a constant integer");
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(
op, "unimplemented: low must be a constant integer");
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step)))
return rewriter.notifyMatchFailure(
op, "unimplemented: low must be a constant integer");
auto endVal = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(start + length * step));
auto none = rewriter.create<ConstantNoneOp>(loc);
rewriter.replaceOpWithNewOp<AtenArangeStartStepOp>(
op, op.getType(), op.getStart(), endVal, op.getStep(), op.getDtype(),
none, op.getDevice(), none);
return success();
}
};
} // namespace

namespace {
// Decompose constant tensor full like ops.
template <typename OpTy, int fillVal>
@@ -7605,6 +7634,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
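The pattern above rewrites `prims.iota` as `aten.arange.start_step`, computing the end value as `start + length * step`. A small sketch of that equivalence in eager PyTorch (the helper name below is hypothetical, for illustration only):

import torch

def iota_via_arange(length, start, step, dtype=torch.int64, device='cpu'):
    # Mirrors the decomposition: end = start + length * step, so exactly
    # `length` elements are produced (the arithmetic also holds for negative steps).
    return torch.arange(start, start + length * step, step,
                        dtype=dtype, device=device)

# What prims.iota specifies: start + i*step for i in range(length).
expected = [10 + i * 3 for i in range(4)]
assert iota_via_arange(4, 10, 3).tolist() == expected
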
5 changes: 5 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1228,6 +1228,7 @@
"PrimMinIntDynamicModule_basic",
"PrimMinIntModule_basic",
"PrimsConvertElementTypeModule_basic",
"PrimsIotaModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
@@ -1789,6 +1790,7 @@
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"PrimListUnpackNumMismatchModule_basic",
"PrimsIotaModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
@@ -2683,6 +2685,9 @@
"SqueezeModule_allUnitDim",
"SqueezeModule_broadcast",
"SqueezeModule_static",

# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",

# Failure - unknown
"BernoulliModule_basic",
@@ -1319,6 +1319,12 @@ def prims〇view_of〡dtype(a_rank_dtype: Tuple[int, int]) -> int:
_, a_dtype = a_rank_dtype
return a_dtype

def prims〇iota〡shape(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> List[int]:
return [length]

def prims〇iota〡dtype(length: int, start: int, step: int, dtype: int, device: device, requires_grad: bool) -> int:
return dtype

def prim〇NumToTensor〇Scalar〡shape(a: float) -> List[int]:
return []

@@ -897,6 +897,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)")

# ==========================================================================
# `quantized::` namespace.
17 changes: 17 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py
@@ -380,3 +380,20 @@ def forward(self):
@register_test_case(module_factory=lambda: LinspaceTwoSizeModule())
def LinspaceTwoSizeModule_basic(module, tu: TestUtils):
module.forward()


class PrimsIotaModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
])
def forward(self):
return torch.ops.prims.iota(77, start=0, step=1, dtype=torch.int64, device='cpu',
requires_grad=False)

@register_test_case(module_factory=lambda: PrimsIotaModule())
def PrimsIotaModule_basic(module, tu: TestUtils):
module.forward()
