Skip to content

Commit

Permalink
[mlir][complex] Canonicalize complex.div by one (llvm#85513)
Browse files Browse the repository at this point in the history
We can canonicalize the complex.div if the divisor is one (real = 1.0,
imag = 0.0) with the input number itself.

Ref: https://www.cuemath.com/numbers/division-of-complex-numbers/
  • Loading branch information
Lewuathe authored Mar 26, 2024
1 parent 3e046ee commit 08a321e
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def DivOp : ComplexArithmeticOp<"div"> {
%a = complex.div %b, %c : complex<f32>
```
}];

let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,32 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
return {};
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
auto rhs = adaptor.getRhs();
if (!rhs)
return {};

ArrayAttr arrayAttr = rhs.dyn_cast<ArrayAttr>();
if (!arrayAttr || arrayAttr.size() != 2)
return {};

APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();

if (!imag.isZero())
return {};

// complex.div(a, complex.constant<1.0, 0.0>) -> a
if (real == APFloat(real.getSemantics(), 1))
return getLhs();

return {};
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Dialect/Complex/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,59 @@ func.func @double_reverse_bitcast(%arg0 : complex<f32>) -> f64 {
// CHECK: return %[[R0]] : f64
func.return %1 : f64
}


// CHECK-LABEL: func @div_one_f16
// CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> complex<f16>
func.func @div_one_f16(%arg0: f16, %arg1: f16) -> complex<f16> {
%create = complex.create %arg0, %arg1: complex<f16>
%one = complex.constant [1.0 : f16, 0.0 : f16] : complex<f16>
%div = complex.div %create, %one : complex<f16>
// CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f16>
// CHECK-NEXT: return %[[CREATE]]
return %div : complex<f16>
}

// CHECK-LABEL: func @div_one_f32
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> complex<f32>
func.func @div_one_f32(%arg0: f32, %arg1: f32) -> complex<f32> {
%create = complex.create %arg0, %arg1: complex<f32>
%one = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
%div = complex.div %create, %one : complex<f32>
// CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f32>
// CHECK-NEXT: return %[[CREATE]]
return %div : complex<f32>
}

// CHECK-LABEL: func @div_one_f64
// CHECK-SAME: (%[[ARG0:.*]]: f64, %[[ARG1:.*]]: f64) -> complex<f64>
func.func @div_one_f64(%arg0: f64, %arg1: f64) -> complex<f64> {
%create = complex.create %arg0, %arg1: complex<f64>
%one = complex.constant [1.0 : f64, 0.0 : f64] : complex<f64>
%div = complex.div %create, %one : complex<f64>
// CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f64>
// CHECK-NEXT: return %[[CREATE]]
return %div : complex<f64>
}

// CHECK-LABEL: func @div_one_f80
// CHECK-SAME: (%[[ARG0:.*]]: f80, %[[ARG1:.*]]: f80) -> complex<f80>
func.func @div_one_f80(%arg0: f80, %arg1: f80) -> complex<f80> {
%create = complex.create %arg0, %arg1: complex<f80>
%one = complex.constant [1.0 : f80, 0.0 : f80] : complex<f80>
%div = complex.div %create, %one : complex<f80>
// CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f80>
// CHECK-NEXT: return %[[CREATE]]
return %div : complex<f80>
}

// CHECK-LABEL: func @div_one_f128
// CHECK-SAME: (%[[ARG0:.*]]: f128, %[[ARG1:.*]]: f128) -> complex<f128>
func.func @div_one_f128(%arg0: f128, %arg1: f128) -> complex<f128> {
%create = complex.create %arg0, %arg1: complex<f128>
%one = complex.constant [1.0 : f128, 0.0 : f128] : complex<f128>
%div = complex.div %create, %one : complex<f128>
// CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f128>
// CHECK-NEXT: return %[[CREATE]]
return %div : complex<f128>
}

0 comments on commit 08a321e

Please sign in to comment.