-
Notifications
You must be signed in to change notification settings - Fork 72
/
Copy pathhlo_ops.td
3946 lines (3321 loc) · 128 KB
/
hlo_ops.td
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* Copyright 2019 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for MHLO ops.
#ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS
#define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/OpBase.td"
include "mhlo/IR/hlo_utils.td"
include "mhlo/IR/hlo_ops_common.td"
// Base class of every MHLO operation.
class MHLO_Op<string mnemonic, list<Trait> traits> :
    Op<MHLO_Dialect, mnemonic, traits> {
  // Set to 1 by ops whose conversion to HLO is custom rather than the
  // default one.
  bit hasCustomHLOConverter = 0;

  let extraClassDeclaration = [{
    // Replaces the strict default return-type compatibility check with the
    // relaxed HLO-specific one.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return mlir::hlo::isCompatibleForHloTypeInference(l, r);
    }
  }];
}
// MHLO op that also declares InferShapedTypeOpInterface's
// `reifyReturnTypeShapes`, to be implemented out of line by each op.
class MHLO_ShapedInterfaceOp<string mnemonic, list<Trait> traits> :
    MHLO_Op<mnemonic, traits # [
        DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                  ["reifyReturnTypeShapes"]>]> {
  // Setting extraClassDeclaration replaces the value inherited from MHLO_Op,
  // so the relaxed return-type compatibility hook is restated here.
  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return mlir::hlo::isCompatibleForHloTypeInference(l, r);
    }
  }];
}
//===----------------------------------------------------------------------===//
// MHLO nullary op definitions.
//===----------------------------------------------------------------------===//
def MHLO_ConstantOp : MHLO_Op<"constant",
    [ConstantLike, Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Constant operation";
  let description = [{
    Produces an `output` tensor from a constant `value`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant

    Example:
    ```mlir
    %output = mhlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
    ```
  }];

  let arguments = (ins ElementsAttr:$value);
  let results = (outs MHLO_StaticShapeTensor:$output);

  let builders = [OpBuilder<(ins "Attribute":$value)>];

  // Parser/printer and folder are hand-written in C++.
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  // Constant has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;

  // Only declared here; the definition lives out of line.
  let extraClassDeclaration = [{
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
  }];
}
def MHLO_IotaOp : MHLO_Op<"iota", [Pure]> {
  let summary = "Iota operation";
  let description = [{
    Fills an `output` tensor with values in increasing order starting from zero
    along the `iota_dimension` dimension.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota

    Example:
    ```mlir
    %output = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4x5xi32>
    ```
  }];

  let arguments = (ins
    ConfinedAttr<I64Attr, [IntNonNegative]>:$iota_dimension
  );
  let results = (outs MHLO_StaticShapeIntFpOrComplexTensor:$output);

  // This op declares neither `assemblyFormat` nor `hasCustomAssemblyFormat`,
  // so it only has MLIR's generic syntax; the example above uses that form.
  // TODO(b/130357376): Iota has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
  let hasCanonicalizer = 1;
  let hasFolder = 1;
  let hasVerifier = 1;
}
def MHLO_DynamicIotaOp: MHLO_ShapedInterfaceOp<"dynamic_iota", [Pure]> {
  let summary = "DynamicIota operation";
  let description = [{
    This operation is functionally identical to
    [iota](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#iota)
    op, but the result shape is specified dynamically via `output_shape`.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dynamic_iota

    Example:
    ```mlir
    %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64}
        : (tensor<1xindex>) -> tensor<4xi32>
    ```
  }];

  let arguments = (ins
    MHLO_DimensionTensor:$output_shape,
    ConfinedAttr<I64Attr, [IntNonNegative]>:$iota_dimension
  );
  let results = (outs MHLO_Tensor:$result);

  // No `assemblyFormat` / `hasCustomAssemblyFormat` is declared, so the op is
  // parsed and printed in MLIR's generic form, as in the example above.
  let hasCanonicalizer = 1;
  // Cannot be exported to legacy formats.
  let hasCustomHLOConverter = 1;
}
def MHLO_CreateTokenOp : MHLO_Op<"create_token",
    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "CreateToken operation";
  let description = [{
    This operation is on its way out of StableHLO, so it is not included in
    the specification: https://github.com/openxla/stablehlo/issues/3.

    Informally, this operation does the same thing as AfterAllOp with 0 inputs:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all

    Example:
    ```mlir
    %output = mhlo.create_token : !mhlo.token
    ```
  }];

  // No operands; the only result is a token.
  let results = (outs MHLO_Token:$output);
  let assemblyFormat = "attr-dict `:` type(results)";
}
//===----------------------------------------------------------------------===//
// MHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// Common base for unary element-wise ops: one operand, one result of the same
// shape. The result type defaults to the operand type unless `ResultType` is
// given explicitly.
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
class MHLO_UnaryElementwiseOp<string mnemonic, list<Trait> traits,
    Type OperandType, Type ResultType = OperandType> :
    MHLO_Op<mnemonic, traits # [Elementwise, InferShapedTypeOpInterface,
                                SameOperandsAndResultShape]> {
  let arguments = (ins OperandType:$operand);
  let results = (outs ResultType:$result);

  let extraClassDeclaration = [{
    // The result shape is derived from the (single) operand.
    LogicalResult reifyReturnTypeShapes(
        OpBuilder& builder, ValueRange operands,
        SmallVectorImpl<Value>& reifiedReturnShapes) {
      return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
                                                 operands.front(),
                                                 &reifiedReturnShapes);
    }
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return mlir::hlo::isCompatibleForHloTypeInference(l, r);
    }
  }];

  let assemblyFormat = [{
    $operand attr-dict
      `:` custom<SameOperandsAndResultType>(type($operand), type($result))
  }];
}
// Abs maps complex operands to real results, so the result element type is
// not guaranteed to match the operand element type.
def MHLO_AbsOp: MHLO_UnaryElementwiseOp<"abs",
    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    RankedTensorOf<[MHLO_SInt, MHLO_Float, MHLO_Complex, MHLO_QuantizedInt]>,
    RankedTensorOf<[MHLO_SInt, MHLO_Float, MHLO_QuantizedInt]>> {
  let summary = "Abs operation";
  let description = [{
    Performs element-wise abs operation on `operand` tensor and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#abs

    Example:
    ```mlir
    %result = mhlo.abs %operand : tensor<3xi32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_CbrtOp: MHLO_UnaryElementwiseOp<"cbrt",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Cbrt operation";
  let description = [{
    Performs element-wise cube root operation on `operand` tensor and produces
    a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cbrt

    Example:
    ```mlir
    %result = mhlo.cbrt %operand : tensor<4xf32>
    ```
  }];
}
def MHLO_CeilOp: MHLO_UnaryElementwiseOp<"ceil",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpOrQuantizedIntTensor> {
  let summary = "Ceil operation";
  let description = [{
    Performs element-wise ceil of `operand` tensor and produces a `result`
    tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#ceil

    Example:
    ```mlir
    %result = mhlo.ceil %operand : tensor<5xf32>
    ```
  }];
}
def MHLO_ConvertOp : MHLO_UnaryElementwiseOp<"convert",
    [Pure, SameOperandsAndResultShape], MHLO_Tensor> {
  let summary = "Convert operation";
  let description = [{
    Performs an element-wise conversion from one element type to another on
    `operand` tensor and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convert

    Example:
    ```mlir
    %result = mhlo.convert %operand : (tensor<3xi32>) -> tensor<3xcomplex<f32>>
    ```
  }];

  // Extra builder that takes the desired result element type.
  let builders = [
    OpBuilder<(ins "Value":$operand, "Type":$result_element_ty)>];

  let hasCanonicalizer = 1;
  let hasCustomHLOConverter = 1;
  let hasFolder = 1;
}
def MHLO_ClzOp: MHLO_UnaryElementwiseOp<"count_leading_zeros",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_IntTensor> {
  let summary = "Clz operation";
  let description = [{
    Performs element-wise count of the number of leading zero bits in the
    `operand` tensor and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#count_leading_zeros

    Example:
    ```mlir
    %result = mhlo.count_leading_zeros %operand : tensor<2x2xi8>
    ```
  }];
}
def MHLO_CosineOp: MHLO_UnaryElementwiseOp<"cosine",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Cosine operation";
  let description = [{
    Performs element-wise cosine operation on `operand` tensor and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine

    Example:
    ```mlir
    %result = mhlo.cosine %operand : tensor<2xf32>
    ```
  }];

  let hasCustomHLOConverter = 1;
  let hasFolder = 1;
}
def MHLO_ErfOp: MHLO_UnaryElementwiseOp<"erf",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpTensor> {
  let summary = "Erf operation";
  let description = [{
    Performs element-wise erf operation on `operand` tensor and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#erf

    Example:
    ```mlir
    %result = mhlo.erf %operand : tensor<2x2xf32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_ExpOp: MHLO_UnaryElementwiseOp<"exponential",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Exp operation";
  let description = [{
    Performs element-wise exponential operation on `operand` tensor and produces
    a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential

    Example:
    ```mlir
    %result = mhlo.exponential %operand : tensor<2x2xf64>
    ```
  }];

  // Replaces the inherited operand list to add an optional `result_accuracy`
  // attribute, which defaults to ResultAccuracyMode::DEFAULT.
  let arguments = (ins
    MHLO_FpComplexOrQuantizedIntTensor:$operand,
    DefaultValuedOptionalAttr<MHLO_ResultAccuracyAttr,
        "::mlir::mhlo::ResultAccuracyMode::DEFAULT">:$result_accuracy
  );
  let results = (outs MHLO_FpComplexOrQuantizedIntTensor:$result);

  let hasFolder = 1;
  let hasVerifier = 1;
}
def MHLO_Expm1Op: MHLO_UnaryElementwiseOp<"exponential_minus_one",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Expm1 operation";
  let description = [{
    Performs element-wise exponential minus one operation on `operand` tensor
    and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential_minus_one

    Example:
    ```mlir
    %result = mhlo.exponential_minus_one %operand : tensor<2xf32>
    ```
  }];
}
def MHLO_FloorOp: MHLO_UnaryElementwiseOp<"floor",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpOrQuantizedIntTensor> {
  let summary = "Floor operation";
  let description = [{
    Performs element-wise floor of `operand` tensor and produces a `result`
    tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#floor

    Example:
    ```mlir
    %result = mhlo.floor %operand : tensor<2xf32>
    ```
  }];
}
// Operand may be floating-point or complex; the result is always
// floating-point, so its type is inferred rather than copied.
def MHLO_ImagOp: MHLO_UnaryElementwiseOp<"imag",
    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    MHLO_FpOrComplexTensor,
    MHLO_FpTensor> {
  let summary = "Imag operation";
  let description = [{
    Extracts the imaginary part, element-wise, from the `operand` and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#imag

    Example:
    ```mlir
    %result = mhlo.imag %operand : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_IsFiniteOp: MHLO_UnaryElementwiseOp<"is_finite",
    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    MHLO_Tensor> {
  let summary = "IsFinite operation";
  let description = [{
    Performs element-wise check whether the value in `x` is finite (i.e. is
    neither +Inf, -Inf, nor NaN) and produces a `y` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#is_finite

    Example:
    ```mlir
    %y = mhlo.is_finite %x : (tensor<7xf32>) -> tensor<7xi1>
    ```
  }];

  // Replaces the inherited `operand`/`result` with `x`/`y`: a floating-point
  // input and a predicate output, so the assembly format is overridden too.
  let arguments = (ins MHLO_FpTensor:$x);
  let results = (outs MHLO_PredTensor:$y);
  let assemblyFormat = [{
    $x attr-dict `:` functional-type(operands, results)
  }];
}
def MHLO_LogOp: MHLO_UnaryElementwiseOp<"log",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Log operation";
  let description = [{
    Performs element-wise logarithm operation on `operand` tensor and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log

    Example:
    ```mlir
    %result = mhlo.log %operand : tensor<2x2xf64>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_Log1pOp: MHLO_UnaryElementwiseOp<"log_plus_one",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Log1p operation";
  let description = [{
    Performs element-wise logarithm plus one operation on `operand` tensor and
    produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#log_plus_one

    Example:
    ```mlir
    %result = mhlo.log_plus_one %operand : tensor<6xf32>
    ```
  }];
}
def MHLO_LogisticOp: MHLO_UnaryElementwiseOp<"logistic",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Logistic operation";
  let description = [{
    Performs element-wise logistic operation on `operand` tensor and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#logistic

    Example:
    ```mlir
    %result = mhlo.logistic %operand : tensor<2x2xf32>
    ```
  }];

  let hasFolder = 1;
}
// Accepts boolean (pred) as well as integer tensors; see MHLO_PredOrIntTensor.
def MHLO_NotOp: MHLO_UnaryElementwiseOp<"not",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_PredOrIntTensor> {
  let summary = "Not operation";
  let description = [{
    Performs element-wise NOT of tensor `operand` of boolean or integer type
    (logical NOT for booleans, bitwise NOT for integers) and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not

    Example:
    ```mlir
    %result = mhlo.not %operand : tensor<5x3x1xi1>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_NegOp: MHLO_UnaryElementwiseOp<"negate",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_IntFpOrComplexOrQuantizedIntTensor> {
  let summary = "Neg operation";
  let description = [{
    Performs element-wise negation of `operand` tensor and produces a `result`
    tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#negate

    Example:
    ```mlir
    %result = mhlo.negate %operand : tensor<2x3xi32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_PopulationCountOp: MHLO_UnaryElementwiseOp<"popcnt",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_IntTensor> {
  let summary = "PopulationCount operation";
  let description = [{
    Performs element-wise count of the number of bits set in the `operand`
    tensor and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt

    Example:
    ```mlir
    %result = mhlo.popcnt %operand : tensor<4xi8>
    ```
  }];
}
// Operand may be floating-point or complex; the result is always
// floating-point, so its type is inferred rather than copied.
def MHLO_RealOp: MHLO_UnaryElementwiseOp<"real",
    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>],
    MHLO_FpOrComplexTensor,
    MHLO_FpTensor> {
  let summary = "Real operation";
  let description = [{
    Extracts the real part, element-wise, from the `operand` and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#real

    Example:
    ```mlir
    %result = mhlo.real %operand : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_RoundOp: MHLO_UnaryElementwiseOp<"round_nearest_afz",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpTensor> {
  let summary = "Round operation";
  let description = [{
    Performs element-wise rounding towards the nearest integer, breaking ties
    away from zero, on the `operand` tensor and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_afz

    Example:
    ```mlir
    %result = mhlo.round_nearest_afz %operand : tensor<5xf32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_RoundNearestEvenOp: MHLO_UnaryElementwiseOp<"round_nearest_even",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpTensor> {
  let summary = "RoundNearestEven operation";
  let description = [{
    Performs element-wise rounding towards the nearest integer, breaking ties
    towards the even integer, on the `operand` tensor and produces a `result`
    tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#round_nearest_even

    Example:
    ```mlir
    %result = mhlo.round_nearest_even %operand : tensor<5xf32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_RsqrtOp: MHLO_UnaryElementwiseOp<"rsqrt",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Rsqrt operation";
  let description = [{
    Performs element-wise reciprocal square root operation on `operand` tensor
    and produces a `result` tensor, implementing the `rSqrt` operation from the
    IEEE-754 specification.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#rsqrt

    Example:
    ```mlir
    %result = mhlo.rsqrt %operand : tensor<2x2xf32>
    ```
  }];

  let hasFolder = 1;
}
// NOTE(review): this type list uses HLO_QuantizedInt while AbsOp uses
// MHLO_QuantizedInt; confirm whether the two constraints are meant to differ.
def MHLO_SignOp: MHLO_UnaryElementwiseOp<"sign",
    [Pure, HLO_CompatibleOperandsAndResultType],
    RankedTensorOf<[MHLO_SInt, MHLO_Float, MHLO_Complex, HLO_QuantizedInt]>> {
  let summary = "Sign operation";
  let description = [{
    Returns the sign of the `operand` element-wise and produces a `result`
    tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sign

    Example:
    ```mlir
    %result = mhlo.sign %operand : tensor<7xf32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_SineOp: MHLO_UnaryElementwiseOp<"sine",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Sine operation";
  let description = [{
    Performs element-wise sine operation on `operand` tensor and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine

    Example:
    ```mlir
    %result = mhlo.sine %operand : tensor<2xf32>
    ```
  }];

  let hasCustomHLOConverter = 1;
  let hasFolder = 1;
}
def MHLO_TanOp: MHLO_UnaryElementwiseOp<"tan",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpOrComplexTensor> {
  let summary = "Tan operation";
  let description = [{
    This operation is a work in progress, so it is not yet included in
    the specification: https://github.com/openxla/stablehlo/issues/954.

    Informally, this operation returns `Tan(operand)` element-wise.

    Example:
    ```mlir
    %0 = mhlo.tan %arg0 : tensor<2xf32>
    ```
  }];

  let hasCustomHLOConverter = 1;
  let hasFolder = 1;
}
def MHLO_SqrtOp: MHLO_UnaryElementwiseOp<"sqrt",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Sqrt operation";
  let description = [{
    Performs element-wise square root operation on `operand` tensor and produces
    a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sqrt

    Example:
    ```mlir
    %result = mhlo.sqrt %operand : tensor<2x2xf32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_TanhOp: MHLO_UnaryElementwiseOp<"tanh",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Tanh operation";
  let description = [{
    Performs element-wise hyperbolic tangent operation on `operand` tensor and
    produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#tanh

    Example:
    ```mlir
    %result = mhlo.tanh %operand : tensor<2xf32>
    ```
  }];

  let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// MHLO binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// Common base for binary element-wise ops: `lhs` and `rhs` share
// `OperandType` (any MHLO tensor by default). The result type defaults to the
// operand type unless `ResultType` is given explicitly.
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
class MHLO_BinaryElementwiseOp<string mnemonic, list<Trait> traits,
    Type OperandType = MHLO_Tensor, Type ResultType = OperandType> :
    MHLO_Op<mnemonic, traits # [InferShapedTypeOpInterface,
                                SameOperandsAndResultShape, Elementwise]> {
  let arguments = (ins
    OperandType:$lhs,
    OperandType:$rhs
  );
  let results = (outs ResultType:$result);

  let extraClassDeclaration = [{
    // The result shape is derived from the first operand (`lhs`).
    LogicalResult reifyReturnTypeShapes(
        OpBuilder& builder, ValueRange operands,
        SmallVectorImpl<Value>& reifiedReturnShapes) {
      return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
                                                 operands.front(),
                                                 &reifiedReturnShapes);
    }
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
      return mlir::hlo::isCompatibleForHloTypeInference(l, r);
    }
  }];

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict
      `:` custom<SameOperandsAndResultType>(type($lhs), type($rhs), type($result))
  }];
}
def MHLO_AddOp : MHLO_BinaryElementwiseOp<"add",
    [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Add operation";
  let description = [{
    Performs element-wise addition of two tensors `lhs` and `rhs` and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#add

    Example:
    ```mlir
    %result = mhlo.add %lhs, %rhs : tensor<2x2xi32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_Atan2Op : MHLO_BinaryElementwiseOp<"atan2",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_FpComplexOrQuantizedIntTensor> {
  let summary = "Atan2 operation";
  let description = [{
    Performs element-wise atan2 operation on `lhs` and `rhs` tensors and
    produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#atan2

    Example:
    ```mlir
    %result = mhlo.atan2 %lhs, %rhs : tensor<3xf32>
    ```
  }];
}
// Note: SameOperandsAndResultShape is listed here and is also appended by
// MHLO_BinaryElementwiseOp.
def MHLO_ComplexOp: MHLO_BinaryElementwiseOp<"complex",
    [Pure, SameOperandsElementType, SameOperandsAndResultShape,
     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Complex operation";
  let description = [{
    Performs element-wise conversion to a complex value from a pair of real and
    imaginary values, `lhs` and `rhs`, and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#complex

    Example:
    ```mlir
    %result = mhlo.complex %lhs, %rhs : tensor<2xcomplex<f32>>
    ```
  }];

  // Replaces the base class constraints: real (MHLO_Fp32Or64Tensor) inputs and
  // a complex output, printed with a dedicated custom type directive.
  let arguments = (ins
    MHLO_Fp32Or64Tensor:$lhs,
    MHLO_Fp32Or64Tensor:$rhs
  );
  let results = (outs MHLO_ComplexTensor:$result);
  let assemblyFormat = [{
    operands attr-dict
      `:` custom<ComplexOpType>(type($lhs), type($rhs), type($result))
  }];

  let hasFolder = 1;
}
def MHLO_DivOp : MHLO_BinaryElementwiseOp<"divide",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_IntFpOrComplexOrQuantizedIntTensor> {
  let summary = "Div operation";
  let description = [{
    Performs element-wise division of dividend `lhs` and divisor `rhs` tensors
    and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#divide

    Example:
    ```mlir
    %result = mhlo.divide %lhs, %rhs : tensor<4xf32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_MaxOp : MHLO_BinaryElementwiseOp<"maximum",
    [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Max operation";
  let description = [{
    Performs element-wise max operation on tensors `lhs` and `rhs` and produces
    a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#maximum

    Example:
    ```mlir
    %result = mhlo.maximum %lhs, %rhs : tensor<4xf32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_MinOp : MHLO_BinaryElementwiseOp<"minimum",
    [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Min operation";
  let description = [{
    Performs element-wise min operation on tensors `lhs` and `rhs` and produces
    a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#minimum

    Example:
    ```mlir
    %result = mhlo.minimum %lhs, %rhs : tensor<4xf32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_MulOp : MHLO_BinaryElementwiseOp<"multiply",
    [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> {
  let summary = "Mul operation";
  let description = [{
    Performs element-wise product of two tensors `lhs` and `rhs` and produces a
    `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#multiply

    Example:
    ```mlir
    %result = mhlo.multiply %lhs, %rhs : tensor<2xi32>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_PowOp : MHLO_BinaryElementwiseOp<"power",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_IntFpOrComplexOrQuantizedIntTensor> {
  let summary = "Pow operation";
  let description = [{
    Performs element-wise exponentiation of `lhs` tensor by `rhs` tensor and
    produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#power

    Example:
    ```mlir
    %result = mhlo.power %lhs, %rhs : tensor<6xf32>
    ```
  }];
}
def MHLO_RemOp : MHLO_BinaryElementwiseOp<"remainder",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_IntFpOrComplexOrQuantizedIntTensor> {
  let summary = "Rem operation";
  let description = [{
    Performs element-wise remainder of dividend `lhs` and divisor `rhs` tensors
    and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#remainder

    Example:
    ```mlir
    %result = mhlo.remainder %lhs, %rhs : tensor<4xi64>
    ```
  }];

  let hasFolder = 1;
}
def MHLO_ShiftLeftOp : MHLO_BinaryElementwiseOp<"shift_left",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_IntTensor> {
  let summary = "ShiftLeft operation";
  let description = [{
    Performs element-wise left-shift operation on the `lhs` tensor by `rhs`
    number of bits and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_left

    Example:
    ```mlir
    %result = mhlo.shift_left %lhs, %rhs : tensor<6xi8>
    ```
  }];
}
def MHLO_ShiftRightArithmeticOp :
    MHLO_BinaryElementwiseOp<"shift_right_arithmetic",
        [Pure, HLO_CompatibleOperandsAndResultType],
        MHLO_IntTensor> {
  let summary = "ShiftRightArithmetic operation";
  let description = [{
    Performs element-wise arithmetic right-shift operation on the `lhs` tensor
    by `rhs` number of bits and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_arithmetic

    Example:
    ```mlir
    %result = mhlo.shift_right_arithmetic %lhs, %rhs : tensor<6xi8>
    ```
  }];
}
def MHLO_ShiftRightLogicalOp :
    MHLO_BinaryElementwiseOp<"shift_right_logical",
        [Pure, HLO_CompatibleOperandsAndResultType],
        MHLO_IntTensor> {
  let summary = "ShiftRightLogical operation";
  let description = [{
    Performs element-wise logical right-shift operation on the `lhs` tensor by
    `rhs` number of bits and produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_logical

    Example:
    ```mlir
    %result = mhlo.shift_right_logical %lhs, %rhs : tensor<6xi8>
    ```
  }];
}
def MHLO_SubtractOp : MHLO_BinaryElementwiseOp<"subtract",
    [Pure, HLO_CompatibleOperandsAndResultType],
    MHLO_IntFpOrComplexOrQuantizedIntTensor> {
  let summary = "Subtract operation";
  let description = [{
    Performs element-wise subtraction of two tensors `lhs` and `rhs` and
    produces a `result` tensor.

    See:
    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#subtract

    Example:
    ```mlir
    %result = mhlo.subtract %lhs, %rhs : tensor<2xi32>
    ```
  }];

  let hasCustomHLOConverter = 1;
  let hasFolder = 1;
}
def MHLO_StochasticConvertOp : MHLO_Op<"stochastic_convert",
    [Pure, Elementwise, AllShapesMatch<["operand", "random", "result"]>]> {
  let summary = "StochasticConvert operation";
  let description = [{
    This operation is a work in progress, so it is not yet included in
    the specification: https://github.com/openxla/stablehlo/issues/295.

    Informally, this operation performs element-wise conversion of values from
    a bigger type to a smaller one with stochastic rounding using the random
    number passed in.
  }];

  // `operand`, `random` and `result` must all have the same shape
  // (AllShapesMatch above); further checks live in the C++ verifier.
  let arguments = (ins
    MHLO_FpTensor:$operand,
    RankedTensorOf<[MHLO_UInt]>:$random
  );
  let results = (outs MHLO_Tensor:$result);

  let hasCustomHLOConverter = 1;
  let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
// MHLO binary logical elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
class MHLO_BinaryBiwiseOrLogicalElementwiseOp<string mnemonic> :
MHLO_BinaryElementwiseOp<mnemonic,
[Commutative, Pure, HLO_CompatibleOperandsAndResultType]> {
let arguments = (ins
MHLO_PredOrIntTensor:$lhs,
MHLO_PredOrIntTensor:$rhs
);