Pretty print for ops with single enum attribute (openxla#161)
See #25 for background, or view the test updates for examples of the new pretty printing.

Implements pretty printing for the following ops:

```
DotOp, RngBitGeneratorOp, RngOp
```

Closes #25
GleasonK committed Oct 24, 2022
1 parent c822026 commit 82df29a
Showing 8 changed files with 161 additions and 32 deletions.
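At a glance, the printed IR changes like this (an illustration assembled from the op examples and test updates below):

```mlir
// Before: generic form with the fully spelled-out enum attribute.
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo.rng_distribution<NORMAL>} : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
// After: custom assembly with the enum printed as a bare value.
%1 = stablehlo.rng %a, %b, %shape, distribution = NORMAL : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
```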
8 changes: 2 additions & 6 deletions stablehlo/dialect/StablehloEnums.td
@@ -178,9 +178,7 @@ def StableHLO_RngDistribution : I32EnumAttr<"RngDistribution",
let cppNamespace = "::mlir::stablehlo";
}

-def StableHLO_RngDistributionAttr : EnumAttr<StableHLO_Dialect, StableHLO_RngDistribution, "rng_distribution"> {
-  let assemblyFormat = "`<` $value `>`";
-}
+def StableHLO_RngDistributionAttr : EnumAttr<StableHLO_Dialect, StableHLO_RngDistribution, "rng_distribution">;

def STABLEHLO_RNG_ALGORITHM_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>;
def STABLEHLO_RNG_ALGORITHM_THREE_FRY : I32EnumAttrCase<"THREE_FRY", 1>;
@@ -197,8 +195,6 @@ def StableHLO_RngAlgorithm : I32EnumAttr<"RngAlgorithm",
let cppNamespace = "::mlir::stablehlo";
}

-def StableHLO_RngAlgorithmAttr : EnumAttr<StableHLO_Dialect, StableHLO_RngAlgorithm, "rng_algorithm"> {
-  let assemblyFormat = "`<` $value `>`";
-}
+def StableHLO_RngAlgorithmAttr : EnumAttr<StableHLO_Dialect, StableHLO_RngAlgorithm, "rng_algorithm">;

#endif // STABLEHLO_DIALECT_STABLEHLO_ENUMS
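Dropping the `assemblyFormat` override makes these attributes fall back to `EnumAttr`'s default `$value` format. Judging by the test updates further down, that changes the full attribute spelling (and is what lets the op printers below emit the bare enum value):

```mlir
// Old spelling, from the removed `<` $value `>` format:
#stablehlo.rng_distribution<NORMAL>
// New spelling, from the default format:
#stablehlo<rng_distribution NORMAL>
```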
38 changes: 38 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
@@ -1025,6 +1025,44 @@ LogicalResult DotOp::verify() {
return success();
}

// PrecisionConfig - Optional attribute, print the array as raw enums
//
// {precision_config = [#stablehlo<precision DEFAULT>,
// #stablehlo<precision DEFAULT>]}
// ==> ..., precision = [DEFAULT, DEFAULT]
void printPrecisionConfig(OpAsmPrinter& p, Operation*,
::mlir::ArrayAttr attrArr) {
// Precision config is an optional attribute; null is passed when it is absent.
if (!attrArr) return;

p << ", precision = [";
llvm::interleaveComma(attrArr, p, [&](Attribute const& attr) {
p << stringifyPrecision(attr.cast<PrecisionAttr>().getValue());
});
p << ']';
}

// Parse the optional `, precision = [...]` clause emitted above back into an ArrayAttr.
ParseResult parsePrecisionConfig(OpAsmParser& parser, mlir::ArrayAttr& attr) {
if (failed(parser.parseOptionalComma())) {
return success(); // No precision config specified
}

if (failed(parser.parseKeyword("precision")) || failed(parser.parseEqual()))
return failure();

SmallVector<Attribute> attrs;
if (failed(parser.parseCommaSeparatedList(
AsmParser::Delimiter::Square, [&]() -> ParseResult {
attrs.push_back(PrecisionAttr::parse(parser, {}));
return success(/*isSuccess=*/bool(attrs.back()));
}))) {
return failure();
}

attr = mlir::ArrayAttr::get(parser.getContext(), attrs);
return success();
}

//===----------------------------------------------------------------------===//
// DotGeneralOp
//===----------------------------------------------------------------------===//
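For context: MLIR's declarative assembly resolves a `custom<PrecisionConfig>(...)` directive (used by `DotOp` below) to this `printPrecisionConfig`/`parsePrecisionConfig` pair by name, so both forms of the new `DotOp` syntax round-trip whether or not the optional attribute is present:

```mlir
// No precision config: the custom printer emits nothing.
%0 = stablehlo.dot %arg0, %arg1 : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<1x1xi32>
// Precision config present: printed as a bare enum list.
%1 = stablehlo.dot %arg0, %arg1, precision = [DEFAULT, DEFAULT] : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<1x1xi32>
```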
35 changes: 35 additions & 0 deletions stablehlo/dialect/StablehloOps.td
@@ -2011,6 +2011,13 @@ def StableHLO_DotOp: StableHLO_Op<"dot",
multiplication.

See https://www.tensorflow.org/xla/operation_semantics#dot.

Example:

```mlir
%0 = stablehlo.dot %arg0, %arg1 : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<1x1xi32>
%1 = stablehlo.dot %arg0, %arg1, precision = [DEFAULT, DEFAULT] : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<1x1xi32>
```
}];
let arguments = (
ins HLO_Tensor:$lhs,
@@ -2025,6 +2032,12 @@ def StableHLO_DotOp: StableHLO_Op<"dot",
return succeeded(mlir::verifyCompatibleShapes(l, r));
}
}];

// Use empty `` to prevent extra whitespace before precision config.
let assemblyFormat = [{
$lhs `,` $rhs `` custom<PrecisionConfig>($precision_config) attr-dict
`:` functional-type(operands, results)
}];
}

def StableHLO_DotGeneralOp: StableHLO_ShapedInterfaceOp<"dot_general", [NoSideEffect]> {
@@ -2676,6 +2689,12 @@ def StableHLO_RngOp : StableHLO_Op<"rng", [InferTensorTypeWithReify, AllElementT
to be scalar valued.

See https://www.tensorflow.org/xla/operation_semantics#rngnormal.

Example:

```mlir
%1 = stablehlo.rng %arg0, %arg1, %0, distribution = NORMAL : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
```
}];
let arguments = (ins
0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$a,
@@ -2694,6 +2713,11 @@ def StableHLO_RngOp : StableHLO_Op<"rng", [InferTensorTypeWithReify, AllElementT
return succeeded(::mlir::verifyCompatibleShapes(l, r));
}
}];

let assemblyFormat = [{
$a `,` $b `,` $shape `,` `distribution` `=` $rng_distribution
attr-dict `:` functional-type(operands, results)
}];
}

def StableHLO_RngBitGeneratorOp : StableHLO_Op<"rng_bit_generator", [NoSideEffect]> {
@@ -2704,6 +2728,12 @@ def StableHLO_RngBitGeneratorOp : StableHLO_Op<"rng_bit_generator", [NoSideEffec
(with the same shape as initial state) and the generated random data.

See https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator.

Example:

```mlir
%0, %1 = stablehlo.rng_bit_generator %arg0, algorithm = PHILOX : (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>)
```
}];
let arguments = (ins
StableHLO_RngAlgorithmAttr:$rng_algorithm,
@@ -2716,6 +2746,11 @@ def StableHLO_RngBitGeneratorOp : StableHLO_Op<"rng_bit_generator", [NoSideEffec
);

let hasVerifier = 1;

let assemblyFormat = [{
$initial_state `,` `algorithm` `=` $rng_algorithm attr-dict
`:` functional-type(operands, results)
}];
}

//===----------------------------------------------------------------------===//
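Connecting the new formats to the tests below: a hand-written sketch of how the first `rng_bit_generator` test case should render once pretty-printed (the pretty form is inferred from the `assemblyFormat` above, not taken from this diff):

```mlir
// Generic form, as written in the tests:
%2, %3 = "stablehlo.rng_bit_generator"(%4) {rng_algorithm = #stablehlo<rng_algorithm DEFAULT>} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>)
// Pretty-printed form per the assemblyFormat above (sketch):
%2, %3 = stablehlo.rng_bit_generator %4, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>)
```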
4 changes: 2 additions & 2 deletions stablehlo/tests/infer_stablehlo.mlir
@@ -173,7 +173,7 @@ func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi3
// CHECK-LABEL: @rng_normal
func.func @rng_normal(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<7xindex> {
%0 = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
%1 = "stablehlo.rng"(%arg0, %arg1, %0) {rng_distribution = #stablehlo.rng_distribution<NORMAL>} : (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<7xf32>
%1 = "stablehlo.rng"(%arg0, %arg1, %0) {rng_distribution = #stablehlo<rng_distribution NORMAL>} : (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<7xf32>
%2 = "hlo_test_infer.get_return_type_components"(%1)
: (tensor<7xf32>) -> tensor<7xindex>
// CHECK: %2 = "hlo_test_infer.return_type_components"(%1) {dims0 = [7], element_type0 = f32} : (tensor<7xf32>) -> tensor<7xindex>
@@ -185,7 +185,7 @@ func.func @rng_normal(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<7xindex>
// CHECK-LABEL: func @rng_uniform
func.func @rng_uniform(%a: tensor<f32>, %b: tensor<f32>) -> tensor<2x3x5xindex> {
%0 = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%1 = "stablehlo.rng"(%a, %b, %0) {rng_distribution = #stablehlo.rng_distribution<UNIFORM>} : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%1 = "stablehlo.rng"(%a, %b, %0) {rng_distribution = #stablehlo<rng_distribution UNIFORM>} : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%2 = "hlo_test_infer.get_return_type_components"(%1)
: (tensor<2x3x5xf32>) -> tensor<2x3x5xindex>
// CHECK: %2 = "hlo_test_infer.return_type_components"(%1) {dims0 = [2, 3, 5], element_type0 = f32} : (tensor<2x3x5xf32>) -> tensor<2x3x5xindex>
36 changes: 18 additions & 18 deletions stablehlo/tests/ops_stablehlo.mlir
@@ -1392,7 +1392,7 @@ func.func @rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<1
%4 = stablehlo.constant dense<[10, 12]> : tensor<2xui64>
%0 = stablehlo.constant dense<[10, 12]> : tensor<2xi32>
%1 = stablehlo.constant dense<3> : tensor<i32>
%2, %3 = "stablehlo.rng_bit_generator"(%4) {rng_algorithm = #stablehlo.rng_algorithm<DEFAULT>} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>)
%2, %3 = "stablehlo.rng_bit_generator"(%4) {rng_algorithm = #stablehlo<rng_algorithm DEFAULT>} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>)
func.return %2, %3 : tensor<2xui64>, tensor<10x12xui32>
}

@@ -1403,7 +1403,7 @@ func.func @rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<1
%0 = stablehlo.constant dense<[10, 12]> : tensor<2xi32>
%1 = stablehlo.constant dense<3> : tensor<i32>
// expected-error@+1 {{output state shape must match initial state shape. Got: 'tensor<2xui64>' and 'tensor<3xui64>'}}
%2, %3 = "stablehlo.rng_bit_generator"(%4) {rng_algorithm = #stablehlo.rng_algorithm<DEFAULT>} : (tensor<2xui64>) -> (tensor<3xui64>, tensor<10x12xui32>)
%2, %3 = "stablehlo.rng_bit_generator"(%4) {rng_algorithm = #stablehlo<rng_algorithm DEFAULT>} : (tensor<2xui64>) -> (tensor<3xui64>, tensor<10x12xui32>)
func.return %2, %3 : tensor<3xui64>, tensor<10x12xui32>
}

@@ -1412,23 +1412,23 @@ func.func @rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<1
// CHECK-LABEL: func @rng_normal
func.func @rng_normal(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<2x3x5xf32> {
%cst = "stablehlo.constant"() {value = dense<[2, 3, 5]> : tensor<3xi64>} : () -> tensor<3xi64>
%0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo.rng_distribution<NORMAL>}: (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}

// -----

// CHECK-LABEL: func @rng_normal_no_constant
func.func @rng_normal_no_constant(%a: tensor<f32>, %b: tensor<f32>, %shape: tensor<3xi64>) -> tensor<?x?x?xf32> {
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo.rng_distribution<NORMAL>}: (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<?x?x?xf32>
func.return %0 : tensor<?x?x?xf32>
}

// -----

// CHECK-LABEL: func @rng_normal_dynamic_dim
func.func @rng_normal_dynamic_dim(%a: tensor<f32>, %b: tensor<f32>, %shape: tensor<?xi64>) -> tensor<*xf32> {
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo.rng_distribution<NORMAL>}: (tensor<f32>, tensor<f32>, tensor<?xi64>) -> tensor<*xf32>
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<f32>, tensor<?xi64>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

@@ -1437,7 +1437,7 @@ func.func @rng_normal_dynamic_dim(%a: tensor<f32>, %b: tensor<f32>, %shape: tens
func.func @rng_normal_invalid_shape(%arg0: tensor<f32>, %arg1: tensor<f32>) {
%cst = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
// expected-error @+1 {{inferred type(s) 'tensor<7xf32>' are incompatible with return type(s) of operation 'tensor<12xf32>'}}
%0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo.rng_distribution<NORMAL>}: (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<12xf32>
%0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<12xf32>
func.return
}

@@ -1446,7 +1446,7 @@ func.func @rng_normal_invalid_shape(%arg0: tensor<f32>, %arg1: tensor<f32>) {
func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
// expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
%0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo.rng_distribution<NORMAL>}: (tensor<1xf32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<1xf32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}

@@ -1455,7 +1455,7 @@ func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor<f32>) -
func.func @rng_normal_invalid_sigma_rank(%mu: tensor<f32>, %sigma: tensor<1xf32>) -> tensor<2x3x5xf32> {
%shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
// expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
%0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo.rng_distribution<NORMAL>}: (tensor<f32>, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}

@@ -1464,7 +1464,7 @@ func.func @rng_normal_invalid_sigma_rank(%mu: tensor<f32>, %sigma: tensor<1xf32>
func.func @rng_normal_invalid_shape_rank(%mu: tensor<f32>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = stablehlo.constant dense<[[2, 3, 5]]> : tensor<1x3xi64>
// expected-error@+1 {{operand #2 must be 1D tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}}
%0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo.rng_distribution<NORMAL>}: (tensor<f32>, tensor<f32>, tensor<1x3xi64>) -> tensor<2x3x5xf32>
%0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<f32>, tensor<1x3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}

@@ -1473,7 +1473,7 @@ func.func @rng_normal_invalid_shape_rank(%mu: tensor<f32>, %sigma: tensor<f32>)
func.func @rng_normal_invalid_type(%arg0: tensor<complex<f32>>, %arg1: tensor<f32>) {
%cst = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
// expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f32>>'}}
%0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo.rng_distribution<NORMAL>}: (tensor<complex<f32>>, tensor<f32>, tensor<1xi64>) -> tensor<7xf32>
%0 = "stablehlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<complex<f32>>, tensor<f32>, tensor<1xi64>) -> tensor<7xf32>
func.return
}

@@ -1482,31 +1482,31 @@ func.func @rng_normal_invalid_type(%arg0: tensor<complex<f32>>, %arg1: tensor<f3
// CHECK-LABEL: func @rng_uniform
func.func @rng_uniform(%a: tensor<f32>, %b: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo.rng_distribution<UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}

// -----

// CHECK-LABEL: func @rng_uniform_no_constant
func.func @rng_uniform_no_constant(%a: tensor<f32>, %b: tensor<f32>, %shape: tensor<3xi64>) -> tensor<?x?x?xf32> {
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo.rng_distribution<UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<?x?x?xf32>
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<?x?x?xf32>
func.return %0 : tensor<?x?x?xf32>
}

// -----

// CHECK-LABEL: func @rng_uniform_dynamic_dim
func.func @rng_uniform_dynamic_dim(%a: tensor<f32>, %b: tensor<f32>, %shape: tensor<?xi64>) -> tensor<*xf32> {
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo.rng_distribution<UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<?xi64>) -> tensor<*xf32>
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<?xi64>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @rng_uniform_invalid_shape(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<7xi64>) {
// expected-error @+1 {{inferred type(s) 'tensor<?x?x?x?x?x?x?xf32>' are incompatible with return type(s) of operation 'tensor<?xf32>'}}
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {rng_distribution = #stablehlo.rng_distribution<UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<7xi64>) -> tensor<?xf32>
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<7xi64>) -> tensor<?xf32>
func.return
}

@@ -1515,7 +1515,7 @@ func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor<f32>) -> ten
func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
// expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo.rng_distribution<UNIFORM>}: (tensor<1xf32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<1xf32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}

@@ -1525,7 +1525,7 @@ func.func @rng_uniform_invalid_b_rank(%a: tensor<f32>, %b: tensor<1xf32>) -> ten
func.func @rng_uniform_invalid_b_rank(%a: tensor<f32>, %b: tensor<1xf32>) -> tensor<2x3x5xf32> {
%shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
// expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}}
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo.rng_distribution<UNIFORM>}: (tensor<f32>, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<f32>, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}

@@ -1534,7 +1534,7 @@ func.func @rng_uniform_invalid_shape_rank(%a: tensor<f32>, %b: tensor<f32>) -> t
func.func @rng_uniform_invalid_shape_rank(%a: tensor<f32>, %b: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = stablehlo.constant dense<[[2, 3, 5]]> : tensor<1x3xi64>
// expected-error@+1 {{operand #2 must be 1D tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}}
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo.rng_distribution<UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<1x3xi64>) -> tensor<2x3x5xf32>
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<1x3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}

@@ -1543,7 +1543,7 @@ func.func @rng_uniform_invalid_type(%a: tensor<complex<f32>>, %b: tensor<f32>) -
func.func @rng_uniform_invalid_type(%a: tensor<complex<f32>>, %b: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = stablehlo.constant dense<[2, 3, 5]> : tensor<3xi64>
// expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<complex<f32>>'}}
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo.rng_distribution<UNIFORM>}: (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<complex<f32>>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}

(Diffs for the remaining 3 changed files are not shown.)
