add e2e support for aten.atan2 (#1117)
 - Includes math-to-libm pass in refbackend for math::atan2 support
qedawkins authored Aug 2, 2022
1 parent 704efdc commit 38d8498
Showing 8 changed files with 166 additions and 9 deletions.
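For context (not part of this diff): torch.atan2(a, b) is the element-wise, quadrant-aware arctangent of a / b, which is the op this commit lowers end to end. A minimal eager-mode sketch:

import torch

a = torch.tensor([1.0, -1.0])
b = torch.tensor([-1.0, -1.0])
torch.atan2(a, b)  # tensor([ 2.3562, -2.3562]), i.e. 3*pi/4 and -3*pi/4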
47 changes: 47 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -655,6 +655,53 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [
}];
}

def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::atan2 : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAtan2Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenAtan2Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenAtan2_Op : Torch_Op<"aten.atan2_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::atan2_ : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAtan2_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenAtan2_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
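The pair of definitions above mirrors PyTorch's value-semantic op and its trailing-underscore in-place variant. A small eager-mode sketch of the in-place form (illustration only, not part of this diff):

import torch

a = torch.rand(3)
b = torch.rand(3)
a.atan2_(b)  # overwrites a with atan2(a, b); corresponds to aten::atan2_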

def Torch_AtenNegOp : Torch_Op<"aten.neg", [
AllowsTypeRefinement,
HasValueSemantics,
30 changes: 21 additions & 9 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -391,6 +391,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<arith::MulIOp>(loc, lhs, rhs);
}
}
if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) {
Type dtype = converter->convertType(atan2.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::FloatType>()) {
atan2.emitError("Atan2 requires floating point result type");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<math::Atan2Op>(loc, lhs, rhs);
}
if (auto gtTensor = dyn_cast<AtenGtTensorOp>(op)) {
AtenGtTensorOp::Adaptor adaptor(operands);
Type lhsDtype = payloadArgs[0].getType();
@@ -926,7 +938,7 @@ class ConvertElementwiseOp : public ConversionPattern {
ConversionPatternRewriter &rewriter) const override {
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp,
AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenAtan2Op,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
@@ -1669,14 +1681,14 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
AtenLogicalOrOp, AtenTriuOp>();
AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
17 changes: 17 additions & 0 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -733,6 +733,23 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

// The result dtype is float32 unless promotion (computed assuming
// possibly-zero rank) yields bfloat16, float64, or no dtype at all.
if (isa<AtenAtan2Op>(op)) {
ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type promotedDtype = getPromotedResultType(
op->getContext(), {&operands[0]->getValue(), &operands[1]->getValue()},
getRankIsNonZeroArray(op->getOperands()));
if (promotedDtype) {
knowledge.dtype = Float32Type::get(op->getContext());
if (promotedDtype.isa<BFloat16Type, Float64Type>())
knowledge.dtype = promotedDtype;
}
incorporateKnowledge(op->getResult(0), knowledge);
return;
}
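A quick eager-mode check of the dtype rule encoded above (illustration only, not part of this diff); the expected dtypes follow PyTorch's own type promotion:

import torch

torch.atan2(torch.ones(2, dtype=torch.int32),
            torch.ones(2, dtype=torch.int64)).dtype    # torch.float32
torch.atan2(torch.ones(2, dtype=torch.float64),
            torch.ones(2, dtype=torch.int32)).dtype    # torch.float64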

// Promote three dtypes.
if (isa<AtenAddmmOp, AtenLerpTensorOp, AtenAddcmulOp, AtenAddcdivOp>(op)) {
auto knowledge =
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6248,6 +6248,10 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.atan2"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
@@ -807,6 +807,9 @@ def aten〇div〇Tensor_mode(self: List[int], other: List[int], rounding_mode: O
def aten〇floor_divide(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇atan2(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
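The shape function reuses the standard broadcast rule, so atan2 broadcasts like any other binary elementwise op (illustration only, not part of this diff):

import torch

torch.atan2(torch.rand(4, 1), torch.rand(1, 5)).shape  # torch.Size([4, 5])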

def aten〇__and__〇Tensor(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

@@ -252,6 +252,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::exp : (Tensor) -> (Tensor)",
"aten::expm1 : (Tensor) -> (Tensor)",
"aten::cos : (Tensor) -> (Tensor)",
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
"aten::neg : (Tensor) -> (Tensor)",
"aten::floor : (Tensor) -> (Tensor)",
"aten::ceil : (Tensor) -> (Tensor)",
@@ -142,6 +142,8 @@ def invoke(*args):
"func.func(refback-expand-ops-for-llvm)",
"func.func(arith-expand)",
"func.func(convert-math-to-llvm)",
# Handle mlir::math ops (e.g. atan2) that convert-math-to-llvm does not
# lower; convert-math-to-libm rewrites them as calls into libm.
"convert-math-to-libm",
"convert-linalg-to-llvm",
"convert-memref-to-llvm",
"func.func(convert-arith-to-llvm)",
71 changes: 71 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -804,6 +804,77 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseAtan2TensorFloatModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, a, b):
return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2TensorFloatModule())
def ElementwiseAtan2TensorFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4), tu.rand(4, 4))


# ==============================================================================


class ElementwiseAtan2TensorIntModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1], torch.int32, True),
([-1], torch.int64, True),
])
def forward(self, a, b):
return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2TensorIntModule())
def ElementwiseAtan2TensorIntModule_basic(module, tu: TestUtils):
module.forward(
torch.randint(1, 10, [4]).type(torch.int32), torch.randint(1, 10, [4]))


# ==============================================================================


class ElementwiseAtan2FloatIntModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
([-1, -1], torch.float64, True),
])
def forward(self, a, b):
return torch.atan2(a, b)


@register_test_case(module_factory=lambda: ElementwiseAtan2FloatIntModule())
def ElementwiseAtan2FloatIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(1, 10, [4, 4], dtype=torch.int32),
tu.rand(4, 4).double())
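The registered cases above are picked up by the torch-mlir e2e test harness. To run just these tests one would typically filter the suite by name; the entry point and flags below are an assumption based on the repository layout at the time and should be checked against the development docs:

# Assumed invocation of the e2e suite; verify against the repo docs.
python -m e2e_testing.torchscript.main --filter 'ElementwiseAtan2' -v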


# ==============================================================================


class ElementwiseLogModule(torch.nn.Module):

def __init__(self):
