Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interpreter for ShiftLeftOp #1428

Merged
merged 1 commit into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -4944,12 +4944,14 @@ of bits and produces a `result` tensor.
#### Examples

```mlir
// %lhs: [-1, -2, 3, 4, 7, 7]
// %rhs: [1, 2, 3, 6, 7, 8]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<6xi8>, tensor<6xi8>) -> tensor<6xi8>
// %result: [-2, -8, 24, 0, -128, 0]
// %lhs: [-1, 0, 1]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [-2, 0, 8]
```

&nbsp;[More Examples](../stablehlo/tests/interpret_shift_left.mlir)

### shift_right_arithmetic

#### Semantics
Expand Down
2 changes: 1 addition & 1 deletion docs/status.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ one of the following tracking labels.
| select_and_scatter | yes | revisit | yes | no | no |
| send | yes | revisit | yes | no | no |
| set_dimension_size | no | yes\* | yes\* | yes | no |
| shift_left | yes | yes | yes | yes | no |
| shift_left | yes | yes | yes | yes | yes |
| shift_right_arithmetic | yes | yes | yes | yes | no |
| shift_right_logical | yes | yes | yes | yes | no |
| sign | yes | yes | yes | yes | yes |
Expand Down
5 changes: 3 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,8 @@ def StableHLO_RemOp : StableHLO_BinaryElementwiseOp<"remainder",
}

def StableHLO_ShiftLeftOp : StableHLO_BinaryElementwiseOp<"shift_left",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntTensor> {
[Pure, HLO_CompatibleOperandsAndResultType /*shift_left_c1*/],
HLO_IntTensor /*shift_left_i1, shift_left_i2*/> { /*shift_left_c1*/
let summary = "ShiftLeft operation";
let description = [{
Performs element-wise left-shift operation on the `lhs` tensor by `rhs`
Expand All @@ -860,7 +861,7 @@ def StableHLO_ShiftLeftOp : StableHLO_BinaryElementwiseOp<"shift_left",

Example:
```mlir
%result = stablehlo.shift_left %lhs, %rhs : tensor<6xi8>
%result = stablehlo.shift_left %lhs, %rhs : tensor<3xi64>
```
}];
}
Expand Down
8 changes: 8 additions & 0 deletions stablehlo/reference/Element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,14 @@ Element rsqrt(const Element &el) {
[](std::complex<double> e) { return 1.0 / std::sqrt(e); });
}

Element shiftLeft(const Element &e1, const Element &e2) {
auto type = e1.getType();
if (!isSupportedIntegerType(type))
report_fatal_error(invalidArgument("Unsupported element type: %s",
debugString(type).c_str()));
return Element(type, e1.getIntegerValue() << e2.getIntegerValue());
}

Element sign(const Element &el) {
Type type = el.getType();

Expand Down
3 changes: 3 additions & 0 deletions stablehlo/reference/Element.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ Element roundNearestEven(const Element &el);
/// Returns reverse square root of Element object.
Element rsqrt(const Element &e);

/// Returns left-shift of Element object e1 by e2.
Element shiftLeft(const Element &e1, const Element &e2);

/// Returns sign of Element object.
Element sign(const Element &e);

Expand Down
15 changes: 15 additions & 0 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ SmallVector<Tensor> eval(
scope.find(selectOp.getPred()), scope.find(selectOp.getOnTrue()),
scope.find(selectOp.getOnFalse()), selectOp.getType());
scope.add(op.getResults(), {runtimeResult});
} else if (auto shiftLeftOp = dyn_cast<ShiftLeftOp>(op)) {
Tensor runtimeLhs = scope.find(shiftLeftOp.getLhs());
Tensor runtimeRhs = scope.find(shiftLeftOp.getRhs());
Tensor runtimeResult =
evalShiftLeftOp(runtimeLhs, runtimeRhs, shiftLeftOp.getType());
scope.add(op.getResults(), {runtimeResult});
} else if (auto signOp = dyn_cast<SignOp>(op)) {
Tensor runtimeOperand = scope.find(signOp.getOperand());
Tensor runtimeResult = evalSignOp(runtimeOperand, signOp.getType());
Expand Down Expand Up @@ -770,6 +776,15 @@ Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue,
return result;
}

Tensor evalShiftLeftOp(const Tensor &lhs, const Tensor &rhs,
ShapedType resultType) {
Tensor result(resultType);
for (auto resultIt = result.index_begin(); resultIt != result.index_end();
++resultIt)
result.set(*resultIt, shiftLeft(lhs.get(*resultIt), rhs.get(*resultIt)));
return result;
}

Tensor evalSignOp(const Tensor &operand, ShapedType resultType) {
Tensor result(resultType);
for (auto it = result.index_begin(); it != result.index_end(); ++it)
Expand Down
2 changes: 2 additions & 0 deletions stablehlo/reference/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ Tensor evalRoundNearestEvenOp(const Tensor &operand, ShapedType resultType);
Tensor evalRsqrtOp(const Tensor &operand, ShapedType resultType);
Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue,
const Tensor &onFalse, ShapedType resultType);
Tensor evalShiftLeftOp(const Tensor &lhs, const Tensor &rhs,
ShapedType resultType);
Tensor evalSignOp(const Tensor &operand, ShapedType resultType);
Tensor evalSineOp(const Tensor &operand, ShapedType resultType);
Tensor evalSliceOp(const Tensor &operand, Index startIndices, Sizes strides,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
9 changes: 9 additions & 0 deletions stablehlo/tests/interpret_shift_left.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

func.func @shift_left_op_test_si64() {
%lhs = stablehlo.constant dense<[-1, 0, 1]> : tensor<3xi64>
%rhs = stablehlo.constant dense<[1, 2, 3]> : tensor<3xi64>
%result = stablehlo.shift_left %lhs, %rhs : tensor<3xi64>
check.expect_eq_const %result, dense<[-2, 0, 8]> : tensor<3xi64>
func.return
}