Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add interpreter for ShiftLeftOp
Browse files Browse the repository at this point in the history
ghpvnist committed Apr 19, 2023
1 parent 65ce9e6 commit 4184f64
Showing 14 changed files with 53 additions and 13 deletions.
10 changes: 6 additions & 4 deletions docs/spec.md
Original file line number Diff line number Diff line change
@@ -4944,12 +4944,14 @@ of bits and produces a `result` tensor.
#### Examples

```mlir
// %lhs: [-1, -2, 3, 4, 7, 7]
// %rhs: [1, 2, 3, 6, 7, 8]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<6xi8>, tensor<6xi8>) -> tensor<6xi8>
// %result: [-2, -8, 24, 0, -128, 0]
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
```

&nbsp;[More Examples](../stablehlo/tests/interpret_shift_left.mlir)

### shift_right_arithmetic

#### Semantics
2 changes: 1 addition & 1 deletion docs/status.md
Original file line number Diff line number Diff line change
@@ -137,7 +137,7 @@ one of the following tracking labels.
| select_and_scatter | yes | revisit | yes | no | no |
| send | yes | revisit | yes | no | no |
| set_dimension_size | no | yes\* | yes\* | yes | no |
| shift_left | yes | yes | yes | yes | no |
| shift_left | yes | yes | yes | yes | yes |
| shift_right_arithmetic | yes | yes | yes | yes | no |
| shift_right_logical | yes | yes | yes | yes | no |
| sign | yes | yes | yes | yes | yes |
5 changes: 3 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
@@ -849,7 +849,8 @@ def StableHLO_RemOp : StableHLO_BinaryElementwiseOp<"remainder",
}

def StableHLO_ShiftLeftOp : StableHLO_BinaryElementwiseOp<"shift_left",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntTensor> {
[Pure, HLO_CompatibleOperandsAndResultType /*shift_left_c1*/],
HLO_IntTensor /*shift_left_i1, shift_left_i2*/> { /*shift_left_c1*/
let summary = "ShiftLeft operation";
let description = [{
Performs element-wise left-shift operation on the `lhs` tensor by `rhs`
@@ -860,7 +861,7 @@ def StableHLO_ShiftLeftOp : StableHLO_BinaryElementwiseOp<"shift_left",

Example:
```mlir
%result = stablehlo.shift_left %lhs, %rhs : tensor<6xi8>
%result = stablehlo.shift_left %lhs, %rhs : tensor<3xi64>
```
}];
}
8 changes: 8 additions & 0 deletions stablehlo/reference/Element.cpp
Original file line number Diff line number Diff line change
@@ -791,6 +791,14 @@ Element rsqrt(const Element &el) {
[](std::complex<double> e) { return 1.0 / std::sqrt(e); });
}

Element shiftLeft(const Element &e1, const Element &e2) {
auto type = e1.getType();
if (!isSupportedIntegerType(type))
report_fatal_error(invalidArgument("Unsupported element type: %s",
debugString(type).c_str()));
return Element(type, e1.getIntegerValue() << e2.getIntegerValue());
}

Element sign(const Element &el) {
Type type = el.getType();

3 changes: 3 additions & 0 deletions stablehlo/reference/Element.h
Original file line number Diff line number Diff line change
@@ -205,6 +205,9 @@ Element roundNearestEven(const Element &el);
/// Returns reverse square root of Element object.
Element rsqrt(const Element &e);

/// Returns left shift of Element object e1 by e2.
Element shiftLeft(const Element &e1, const Element &e2);

/// Returns sign of Element object.
Element sign(const Element &e);

15 changes: 15 additions & 0 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
@@ -277,6 +277,12 @@ SmallVector<Tensor> eval(
scope.find(selectOp.getPred()), scope.find(selectOp.getOnTrue()),
scope.find(selectOp.getOnFalse()), selectOp.getType());
scope.add(op.getResults(), {runtimeResult});
} else if (auto shiftLeftOp = dyn_cast<ShiftLeftOp>(op)) {
Tensor runtimeLhs = scope.find(shiftLeftOp.getLhs());
Tensor runtimeRhs = scope.find(shiftLeftOp.getRhs());
Tensor runtimeResult =
evalShiftLeftOp(runtimeLhs, runtimeRhs, shiftLeftOp.getType());
scope.add(op.getResults(), {runtimeResult});
} else if (auto signOp = dyn_cast<SignOp>(op)) {
Tensor runtimeOperand = scope.find(signOp.getOperand());
Tensor runtimeResult = evalSignOp(runtimeOperand, signOp.getType());
@@ -770,6 +776,15 @@ Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue,
return result;
}

Tensor evalShiftLeftOp(const Tensor &lhs, const Tensor &rhs,
ShapedType resultType) {
Tensor result(resultType);
for (auto resultIt = result.index_begin(); resultIt != result.index_end();
++resultIt)
result.set(*resultIt, shiftLeft(lhs.get(*resultIt), rhs.get(*resultIt)));
return result;
}

Tensor evalSignOp(const Tensor &operand, ShapedType resultType) {
Tensor result(resultType);
for (auto it = result.index_begin(); it != result.index_end(); ++it)
2 changes: 2 additions & 0 deletions stablehlo/reference/Ops.h
Original file line number Diff line number Diff line change
@@ -87,6 +87,8 @@ Tensor evalRoundNearestEvenOp(const Tensor &operand, ShapedType resultType);
Tensor evalRsqrtOp(const Tensor &operand, ShapedType resultType);
Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue,
const Tensor &onFalse, ShapedType resultType);
Tensor evalShiftLeftOp(const Tensor &lhs, const Tensor &rhs,
ShapedType resultType);
Tensor evalSignOp(const Tensor &operand, ShapedType resultType);
Tensor evalSineOp(const Tensor &operand, ShapedType resultType);
Tensor evalSliceOp(const Tensor &operand, Index startIndices, Sizes strides,
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

9 changes: 9 additions & 0 deletions stablehlo/tests/interpret_shift_left.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

func.func @shift_left_op_test_si64() {
%lhs = stablehlo.constant dense<[-1, 0, 1]> : tensor<3xi64>
%rhs = stablehlo.constant dense<[1, 2, 3]> : tensor<3xi64>
%result = stablehlo.shift_left %lhs, %rhs : tensor<3xi64>
check.expect_eq_const %result, dense<[-2, 0, 8]> : tensor<3xi64>
func.return
}

0 comments on commit 4184f64

Please sign in to comment.