Skip to content

Commit

Permalink
Add interpreter for ShiftRightLogicalOp (#1429)
Browse files Browse the repository at this point in the history
Here are the constraints for the ShiftRightLogicalOp:
```
(I1) lhs is a tensor of integer type.
(I2) rhs is a tensor of integer type.
(C1) `lhs`, `rhs`, and `result` have the same type.
```
I1, I2, and C1 are covered by the ODS, so no additional tests are added.

Notes:
* Corner cases (shift overflow) has not been accounted for: #1150

closes #1114
  • Loading branch information
ghpvnist authored Apr 20, 2023
1 parent 86b2fa6 commit a672298
Show file tree
Hide file tree
Showing 14 changed files with 54 additions and 13 deletions.
10 changes: 6 additions & 4 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -5014,12 +5014,14 @@ number of bits and produces a `result` tensor.
#### Examples

```mlir
// %lhs: [-1, -128, -36, 5, 3, 7]
// %rhs: [1, 2, 3, 2, 1, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<6xi8>, tensor<6xi8>) -> tensor<6xi8>
// %result: [127, 32, 27, 1, 1, 0]
// %lhs: [-1, 0, 8]
// %rhs: [1, 2, 3]
%result = "stablehlo.shift_right_logical"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64>
// %result: [9223372036854775807, 0, 1]
```

&nbsp;[More Examples](../stablehlo/tests/interpret_shift_right_logical.mlir)

### sign

#### Semantics
Expand Down
2 changes: 1 addition & 1 deletion docs/status.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ one of the following tracking labels.
| set_dimension_size | no | yes\* | yes\* | yes | no |
| shift_left | yes | yes | yes | yes | yes |
| shift_right_arithmetic | yes | yes | yes | yes | no |
| shift_right_logical | yes | yes | yes | yes | no |
| shift_right_logical | yes | yes | yes | yes | yes |
| sign | yes | yes | yes | yes | yes |
| sine | yes | yes | yes | yes | yes |
| slice | yes | yes | yes | no | yes |
Expand Down
5 changes: 3 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,8 @@ def StableHLO_ShiftRightArithmeticOp : StableHLO_BinaryElementwiseOp<"shift_righ
}

def StableHLO_ShiftRightLogicalOp : StableHLO_BinaryElementwiseOp<"shift_right_logical",
[Pure, HLO_CompatibleOperandsAndResultType], HLO_IntTensor> {
[Pure, HLO_CompatibleOperandsAndResultType /*shift_right_logical_c1*/],
HLO_IntTensor /*shift_right_logical_i1, shift_right_logical_i2*/> { /*shift_right_logical_c1*/
let summary = "ShiftRightLogical operation";
let description = [{
Performs element-wise logical right-shift operation on the `lhs` tensor by
Expand All @@ -895,7 +896,7 @@ def StableHLO_ShiftRightLogicalOp : StableHLO_BinaryElementwiseOp<"shift_right_l

Example:
```mlir
%result = stablehlo.shift_right_logical %lhs, %rhs : tensor<6xi8>
%result = stablehlo.shift_right_logical %lhs, %rhs : tensor<3xi64>
```
}];
}
Expand Down
8 changes: 8 additions & 0 deletions stablehlo/reference/Element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,14 @@ Element shiftLeft(const Element &e1, const Element &e2) {
return Element(type, e1.getIntegerValue() << e2.getIntegerValue());
}

Element shiftRightLogical(const Element &e1, const Element &e2) {
auto type = e1.getType();
if (!isSupportedIntegerType(type))
report_fatal_error(invalidArgument("Unsupported element type: %s",
debugString(type).c_str()));
return Element(type, e1.getIntegerValue().lshr(e2.getIntegerValue()));
}

Element sign(const Element &el) {
Type type = el.getType();

Expand Down
3 changes: 3 additions & 0 deletions stablehlo/reference/Element.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ Element rsqrt(const Element &e);
/// Returns left-shift of Element object e1 by e2.
Element shiftLeft(const Element &e1, const Element &e2);

/// Returns logical right-shift of Element object e1 by e2.
Element shiftRightLogical(const Element &e1, const Element &e2);

/// Returns sign of Element object.
Element sign(const Element &e);

Expand Down
16 changes: 16 additions & 0 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,12 @@ SmallVector<Tensor> eval(
Tensor runtimeResult =
evalShiftLeftOp(runtimeLhs, runtimeRhs, shiftLeftOp.getType());
scope.add(op.getResults(), {runtimeResult});
} else if (auto shiftRightLogicalOp = dyn_cast<ShiftRightLogicalOp>(op)) {
Tensor runtimeLhs = scope.find(shiftRightLogicalOp.getLhs());
Tensor runtimeRhs = scope.find(shiftRightLogicalOp.getRhs());
Tensor runtimeResult = evalShiftRightLogicalOp(
runtimeLhs, runtimeRhs, shiftRightLogicalOp.getType());
scope.add(op.getResults(), {runtimeResult});
} else if (auto signOp = dyn_cast<SignOp>(op)) {
Tensor runtimeOperand = scope.find(signOp.getOperand());
Tensor runtimeResult = evalSignOp(runtimeOperand, signOp.getType());
Expand Down Expand Up @@ -799,6 +805,16 @@ Tensor evalShiftLeftOp(const Tensor &lhs, const Tensor &rhs,
return result;
}

Tensor evalShiftRightLogicalOp(const Tensor &lhs, const Tensor &rhs,
ShapedType resultType) {
Tensor result(resultType);
for (auto resultIt = result.index_begin(); resultIt != result.index_end();
++resultIt)
result.set(*resultIt,
shiftRightLogical(lhs.get(*resultIt), rhs.get(*resultIt)));
return result;
}

Tensor evalSignOp(const Tensor &operand, ShapedType resultType) {
Tensor result(resultType);
for (auto it = result.index_begin(); it != result.index_end(); ++it)
Expand Down
2 changes: 2 additions & 0 deletions stablehlo/reference/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue,
const Tensor &onFalse, ShapedType resultType);
Tensor evalShiftLeftOp(const Tensor &lhs, const Tensor &rhs,
ShapedType resultType);
Tensor evalShiftRightLogicalOp(const Tensor &lhs, const Tensor &rhs,
ShapedType resultType);
Tensor evalSignOp(const Tensor &operand, ShapedType resultType);
Tensor evalSineOp(const Tensor &operand, ShapedType resultType);
Tensor evalSliceOp(const Tensor &operand, Index startIndices, Sizes strides,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s)
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

Expand Down
9 changes: 9 additions & 0 deletions stablehlo/tests/interpret_shift_right_logical.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

func.func @shift_right_logical_op_test_si64() {
%lhs = stablehlo.constant dense<[-1, 0, 8]> : tensor<3xi64>
%rhs = stablehlo.constant dense<[1, 2, 3]> : tensor<3xi64>
%result = stablehlo.shift_right_logical %lhs, %rhs : tensor<3xi64>
check.expect_eq_const %result, dense<[9223372036854775807, 0, 1]> : tensor<3xi64>
func.return
}

0 comments on commit a672298

Please sign in to comment.