From 5de2178c60352ad84dd3ef151d1360dabf9a6b51 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Wed, 19 Apr 2023 23:10:09 +0000 Subject: [PATCH] Add interpreter for ShiftLeftOp --- docs/spec.md | 10 ++++++---- docs/status.md | 2 +- stablehlo/dialect/StablehloOps.td | 5 +++-- stablehlo/reference/Element.cpp | 8 ++++++++ stablehlo/reference/Element.h | 3 +++ stablehlo/reference/Ops.cpp | 15 +++++++++++++++ stablehlo/reference/Ops.h | 2 ++ ...t_dtypes_lhs_int16_20_20__rhs_int16_20_20.mlir | 2 +- ...t_dtypes_lhs_int32_20_20__rhs_int32_20_20.mlir | 2 +- ...eft_dtypes_lhs_int8_20_20__rhs_int8_20_20.mlir | 2 +- ...dtypes_lhs_uint16_20_20__rhs_uint16_20_20.mlir | 2 +- ...dtypes_lhs_uint32_20_20__rhs_uint32_20_20.mlir | 2 +- ...t_dtypes_lhs_uint8_20_20__rhs_uint8_20_20.mlir | 2 +- stablehlo/tests/interpret_shift_left.mlir | 9 +++++++++ 14 files changed, 53 insertions(+), 13 deletions(-) create mode 100644 stablehlo/tests/interpret_shift_left.mlir diff --git a/docs/spec.md b/docs/spec.md index 2c3c723debd..8eb9e236a50 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -4944,12 +4944,14 @@ of bits and produces a `result` tensor. #### Examples ```mlir -// %lhs: [-1, -2, 3, 4, 7, 7] -// %rhs: [1, 2, 3, 6, 7, 8] -%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<6xi8>, tensor<6xi8>) -> tensor<6xi8> -// %result: [-2, -8, 24, 0, -128, 0] +// %lhs: [-1, 0, 1] +// %rhs: [1, 2, 3] +%result = "stablehlo.shift_left"(%lhs, %rhs): (tensor<3xi64>, tensor<3xi64>) -> tensor<3xi64> +// %result: [-2, 0, 8] ``` + [More Examples](../stablehlo/tests/interpret_shift_left.mlir) + ### shift_right_arithmetic #### Semantics diff --git a/docs/status.md b/docs/status.md index 5b3f68092c1..e2a56928125 100644 --- a/docs/status.md +++ b/docs/status.md @@ -137,7 +137,7 @@ one of the following tracking labels. | select_and_scatter | yes | revisit | yes | no | no | | send | yes | revisit | yes | no | no | | set_dimension_size | no | yes\* | yes\* | yes | no | -| shift_left | yes | yes | yes | yes | no | +| shift_left | yes | yes | yes | yes | yes | | shift_right_arithmetic | yes | yes | yes | yes | no | | shift_right_logical | yes | yes | yes | yes | no | | sign | yes | yes | yes | yes | yes | diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 366c5a5eda4..f8c8db46d37 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -849,7 +849,8 @@ def StableHLO_RemOp : StableHLO_BinaryElementwiseOp<"remainder", } def StableHLO_ShiftLeftOp : StableHLO_BinaryElementwiseOp<"shift_left", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_IntTensor> { + [Pure, HLO_CompatibleOperandsAndResultType /*shift_left_c1*/], + HLO_IntTensor /*shift_left_i1, shift_left_i2*/> { /*shift_left_c1*/ let summary = "ShiftLeft operation"; let description = [{ Performs element-wise left-shift operation on the `lhs` tensor by `rhs` @@ -860,7 +861,7 @@ def StableHLO_ShiftLeftOp : StableHLO_BinaryElementwiseOp<"shift_left", Example: ```mlir - %result = stablehlo.shift_left %lhs, %rhs : tensor<6xi8> + %result = stablehlo.shift_left %lhs, %rhs : tensor<3xi64> ``` }]; } diff --git a/stablehlo/reference/Element.cpp b/stablehlo/reference/Element.cpp index 737f27a072e..2675e150d7e 100644 --- a/stablehlo/reference/Element.cpp +++ b/stablehlo/reference/Element.cpp @@ -791,6 +791,14 @@ Element rsqrt(const Element &el) { [](std::complex e) { return 1.0 / std::sqrt(e); }); } +Element shiftLeft(const Element &e1, const Element &e2) { + auto type = e1.getType(); + if (!isSupportedIntegerType(type)) + report_fatal_error(invalidArgument("Unsupported element type: %s", + debugString(type).c_str())); + return Element(type, e1.getIntegerValue() << e2.getIntegerValue()); +} + Element sign(const Element &el) { Type type = el.getType(); diff --git a/stablehlo/reference/Element.h b/stablehlo/reference/Element.h index 5da3d6e39d9..6230c10019a 100644 --- a/stablehlo/reference/Element.h +++ b/stablehlo/reference/Element.h @@ -205,6 +205,9 @@ Element roundNearestEven(const Element &el); /// Returns reverse square root of Element object. Element rsqrt(const Element &e); +/// Returns left-shift of Element object e1 by e2. +Element shiftLeft(const Element &e1, const Element &e2); + /// Returns sign of Element object. Element sign(const Element &e); diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index 13201a8fecb..cd33494a0b2 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -277,6 +277,12 @@ SmallVector eval( scope.find(selectOp.getPred()), scope.find(selectOp.getOnTrue()), scope.find(selectOp.getOnFalse()), selectOp.getType()); scope.add(op.getResults(), {runtimeResult}); + } else if (auto shiftLeftOp = dyn_cast(op)) { + Tensor runtimeLhs = scope.find(shiftLeftOp.getLhs()); + Tensor runtimeRhs = scope.find(shiftLeftOp.getRhs()); + Tensor runtimeResult = + evalShiftLeftOp(runtimeLhs, runtimeRhs, shiftLeftOp.getType()); + scope.add(op.getResults(), {runtimeResult}); } else if (auto signOp = dyn_cast(op)) { Tensor runtimeOperand = scope.find(signOp.getOperand()); Tensor runtimeResult = evalSignOp(runtimeOperand, signOp.getType()); @@ -770,6 +776,15 @@ Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue, return result; } +Tensor evalShiftLeftOp(const Tensor &lhs, const Tensor &rhs, + ShapedType resultType) { + Tensor result(resultType); + for (auto resultIt = result.index_begin(); resultIt != result.index_end(); + ++resultIt) + result.set(*resultIt, shiftLeft(lhs.get(*resultIt), rhs.get(*resultIt))); + return result; +} + Tensor evalSignOp(const Tensor &operand, ShapedType resultType) { Tensor result(resultType); for (auto it = result.index_begin(); it != result.index_end(); ++it) diff --git a/stablehlo/reference/Ops.h b/stablehlo/reference/Ops.h index 1c3cbecb4e2..70633ba3acd 100644 --- a/stablehlo/reference/Ops.h +++ b/stablehlo/reference/Ops.h @@ -87,6 +87,8 @@ Tensor evalRoundNearestEvenOp(const Tensor &operand, ShapedType resultType); Tensor evalRsqrtOp(const Tensor &operand, ShapedType resultType); Tensor evalSelectOp(const Tensor &pred, const Tensor &onTrue, const Tensor &onFalse, ShapedType resultType); +Tensor evalShiftLeftOp(const Tensor &lhs, const Tensor &rhs, + ShapedType resultType); Tensor evalSignOp(const Tensor &operand, ShapedType resultType); Tensor evalSineOp(const Tensor &operand, ShapedType resultType); Tensor evalSliceOp(const Tensor &operand, Index startIndices, Sizes strides, diff --git a/stablehlo/testdata/shift_left_dtypes_lhs_int16_20_20__rhs_int16_20_20.mlir b/stablehlo/testdata/shift_left_dtypes_lhs_int16_20_20__rhs_int16_20_20.mlir index 6cbb5bb9fe6..435365f7099 100644 --- a/stablehlo/testdata/shift_left_dtypes_lhs_int16_20_20__rhs_int16_20_20.mlir +++ b/stablehlo/testdata/shift_left_dtypes_lhs_int16_20_20__rhs_int16_20_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret +// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s) // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) diff --git a/stablehlo/testdata/shift_left_dtypes_lhs_int32_20_20__rhs_int32_20_20.mlir b/stablehlo/testdata/shift_left_dtypes_lhs_int32_20_20__rhs_int32_20_20.mlir index 2e5856fb4f5..5ccb2ecd945 100644 --- a/stablehlo/testdata/shift_left_dtypes_lhs_int32_20_20__rhs_int32_20_20.mlir +++ b/stablehlo/testdata/shift_left_dtypes_lhs_int32_20_20__rhs_int32_20_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret +// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s) // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) diff --git a/stablehlo/testdata/shift_left_dtypes_lhs_int8_20_20__rhs_int8_20_20.mlir b/stablehlo/testdata/shift_left_dtypes_lhs_int8_20_20__rhs_int8_20_20.mlir index 635538c5474..ef7292358ee 100644 --- a/stablehlo/testdata/shift_left_dtypes_lhs_int8_20_20__rhs_int8_20_20.mlir +++ b/stablehlo/testdata/shift_left_dtypes_lhs_int8_20_20__rhs_int8_20_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret +// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s) // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) diff --git a/stablehlo/testdata/shift_left_dtypes_lhs_uint16_20_20__rhs_uint16_20_20.mlir b/stablehlo/testdata/shift_left_dtypes_lhs_uint16_20_20__rhs_uint16_20_20.mlir index fef2731306b..ab9c661d06b 100644 --- a/stablehlo/testdata/shift_left_dtypes_lhs_uint16_20_20__rhs_uint16_20_20.mlir +++ b/stablehlo/testdata/shift_left_dtypes_lhs_uint16_20_20__rhs_uint16_20_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret +// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s) // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) diff --git a/stablehlo/testdata/shift_left_dtypes_lhs_uint32_20_20__rhs_uint32_20_20.mlir b/stablehlo/testdata/shift_left_dtypes_lhs_uint32_20_20__rhs_uint32_20_20.mlir index 97a866085d5..33ea480db5e 100644 --- a/stablehlo/testdata/shift_left_dtypes_lhs_uint32_20_20__rhs_uint32_20_20.mlir +++ b/stablehlo/testdata/shift_left_dtypes_lhs_uint32_20_20__rhs_uint32_20_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret +// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s) // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) diff --git a/stablehlo/testdata/shift_left_dtypes_lhs_uint8_20_20__rhs_uint8_20_20.mlir b/stablehlo/testdata/shift_left_dtypes_lhs_uint8_20_20__rhs_uint8_20_20.mlir index 0cb1d73a9e0..b0a9b3866d4 100644 --- a/stablehlo/testdata/shift_left_dtypes_lhs_uint8_20_20__rhs_uint8_20_20.mlir +++ b/stablehlo/testdata/shift_left_dtypes_lhs_uint8_20_20__rhs_uint8_20_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret +// RUN: stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt -inline | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --deserialize %s.0_9_0.bc | stablehlo-opt) <(stablehlo-opt %s) // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) diff --git a/stablehlo/tests/interpret_shift_left.mlir b/stablehlo/tests/interpret_shift_left.mlir new file mode 100644 index 00000000000..6e8a316a727 --- /dev/null +++ b/stablehlo/tests/interpret_shift_left.mlir @@ -0,0 +1,9 @@ +// RUN: stablehlo-translate --interpret -split-input-file %s + +func.func @shift_left_op_test_si64() { + %lhs = stablehlo.constant dense<[-1, 0, 1]> : tensor<3xi64> + %rhs = stablehlo.constant dense<[1, 2, 3]> : tensor<3xi64> + %result = stablehlo.shift_left %lhs, %rhs : tensor<3xi64> + check.expect_eq_const %result, dense<[-2, 0, 8]> : tensor<3xi64> + func.return +}