Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement lowering of torch.aten.all.dim #2873

Merged
merged 4 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,10 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
if (isa<AtenLinalgVectorNormOp>(op) || isa<AtenFrobeniusNormDimOp>(op))
return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));

if (auto all = dyn_cast<AtenAllDimOp>(op)) {
mmakevic marked this conversation as resolved.
Show resolved Hide resolved
return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
}

op->emitError("unimplemented lowering in createInitElementForReduceOp");
return nullptr;
}
Expand Down Expand Up @@ -357,6 +361,11 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
auto ord = b.create<arith::ConstantOp>(loc, twoAttr);
auto pow = b.create<math::PowFOp>(loc, abs, ord);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (auto allOp = dyn_cast<AtenAllDimOp>(op)) {
mmakevic marked this conversation as resolved.
Show resolved Hide resolved
Value elem = payloadArgs[0];
Value result = payloadArgs[1];
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
return b.create<arith::MulIOp>(loc, self, result);
}
op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
return nullptr;
Expand Down Expand Up @@ -447,6 +456,9 @@ class ConvertReductionOp : public ConversionPattern {
if (auto normOp = dyn_cast<AtenFrobeniusNormDimOp>(op))
return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter);

if (auto allOp = dyn_cast<AtenAllDimOp>(op))
return computeReductionOpInfoForDimVariantOp(allOp, operands, rewriter);

return rewriter.notifyMatchFailure(op, "not a supported reduce op");
}

Expand Down Expand Up @@ -535,6 +547,9 @@ class ConvertReductionOp : public ConversionPattern {
!elemType.isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(
op, "only float types are valid for vector norm ops");
if((isa<AtenAllDimOp>(op)) && elemType.isa<mlir::IntegerType>() &&
mmakevic marked this conversation as resolved.
Show resolved Hide resolved
elemType.getIntOrFloatBitWidth() == 8)
return rewriter.notifyMatchFailure(op, "uint8 is not supported");
// No checks for all other reduction operations
return success();
}
Expand Down Expand Up @@ -610,6 +625,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenProdDimIntOp>();
target.addIllegalOp<AtenMaxOp>();
target.addIllegalOp<AtenMinOp>();
target.addIllegalOp<AtenAllDimOp>();
target.addIllegalOp<AtenLinalgVectorNormOp>();
target.addIllegalOp<AtenFrobeniusNormDimOp>();
patterns.add<ConvertReductionOp>(typeConverter, context);
Expand Down
17 changes: 17 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7006,6 +7006,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.all.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max.dim\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = torch.derefine %arg1 : !torch.int to !torch.optional<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
Expand Down Expand Up @@ -11809,6 +11814,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.all.dim\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %int0 = torch.constant.int 0\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
" torch.prim.If.yield %0#1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int11 : !torch.int\n"
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.min\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,9 @@ def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]:
def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
return upstream_shape_functions.argmax(self, dim, keepdim)

def aten〇all〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]:
return upstream_shape_functions.argmax(self, dim, keepdim)

def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]:
reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim)
return reduced_shape, reduced_shape
Expand Down Expand Up @@ -3766,6 +3769,13 @@ def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim
return self_dtype
return torch.bool

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
def aten〇all〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int:
self_rank, self_dtype = self_rank_dtype
if self_dtype == torch.uint8:
return self_dtype
return torch.bool

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇min〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
54 changes: 54 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,60 @@ def ReduceProdDimIntFloatModule_basic(module, tu: TestUtils):

# ==============================================================================

class ReduceAllDimFloat(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1,-1,-1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.all(a, dim=1, keepdim=True)

@register_test_case(module_factory=lambda: ReduceAllDimFloat())
def ReduceAllDimFloat_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ReduceAllDimInt(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1,-1,-1], torch.int32, True),
])
def forward(self, a):
return torch.ops.aten.all(a, dim=1, keepdim=True)

@register_test_case(module_factory=lambda: ReduceAllDimInt())
def ReduceAllDimInt_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, 5).to(torch.int32))

# ==============================================================================

class ReduceAllDimBool(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1,-1,-1], torch.bool, True),
])
def forward(self, a):
return torch.ops.aten.all(a, dim=1, keepdim=False)

@register_test_case(module_factory=lambda: ReduceAllDimBool())
def ReduceAllDimBool_basic(module, tu: TestUtils):
module.forward(tu.randint(3,4,5, low=0, high=2).bool())

# ==============================================================================

class ReduceMaxAlongDim(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down