-
Notifications
You must be signed in to change notification settings - Fork 72
/
lhlo_ops.td
1683 lines (1457 loc) · 58.3 KB
/
lhlo_ops.td
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for LMHLO, the "late" MHLO variant of
// the dialect, which operates on buffers instead of tensors.
//
// This file largely overlaps with hlo_ops.td at a logical level. It's tempting
// to merge these two files together, but we need to consider the following
// obstacles:
// * We need to have a common representation for arguments. That is to say,
// HLO_Array<X> translates to MHLO_Tensor<X> in HLO dialect, and
// Arg<LHLO_Buffer<X>, "", [Mem(Read|Write)]> in LHLO. Array types within
// tuples also need to be transformed.
// * As of now, TableGen's dag functions are not sufficient to accomplish the
// one above.
// * Traits aren't identical, but need to be copied. For example,
// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO.
// * Also, currently HLO describes the API in XLA's client side, not service
// side. LHLO aims for the service side.
#ifndef LHLO_OPS
#define LHLO_OPS
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/CopyOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "lhlo/IR/lhlo_dialect.td"
include "lhlo/IR/lhlo_ops_base.td"
include "lhlo/IR/lhlo_ops_structs.td"
include "lhlo/IR/lhlo_structured_interface.td"
//===----------------------------------------------------------------------===//
// LMHLO nullary op definitions.
//===----------------------------------------------------------------------===//
// Common base class for the ops defined in this file. Every op receives
// read+write memory effects and LmhloStructuredInterface; the caller-supplied
// `traits` are appended after those two.
class LHLO_Op<string mnemonic, list<Trait> traits>
    : Op<LHLO_Dialect, mnemonic,
         !listconcat(
             [MemoryEffects<[MemRead, MemWrite]>, LmhloStructuredInterface],
             traits)>;
// Constant op: the value is carried by the $value attribute and $output is the
// buffer declared with a MemWrite effect. A canonicalizer is declared.
def LHLO_ConstantOp
    : LHLO_Op<"constant", []> {
  let summary = "Constant operator";
  let description = [{
Represents a constant value.
}];
  let arguments = (ins
    ElementsAttr:$value,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
  let hasCanonicalizer = 1;
}
// Iota op: takes an $iota_dimension attribute and a write-only $output buffer.
// NOTE(review): the description speaks of a rank 1 array although the op
// carries $iota_dimension -- confirm whether higher-rank outputs are intended.
def LHLO_IotaOp
    : LHLO_Op<"iota", []> {
  let summary = "Iota operator";
  let description = [{
Creates a rank 1 array of values starting at zero and incrementing by one.
}];
  let arguments = (ins
    I64Attr:$iota_dimension,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
//===----------------------------------------------------------------------===//
// LMHLO unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
//
// Base class for one-input element-wise ops. `BufferType` constrains both the
// read-only $input and the write-only $output. Unless overridden, the traits
// are SameTypeOperands and Elementwise.
class LHLO_UnaryElementwiseOp<
    string mnemonic,
    Type BufferType = LHLO_Buffer,
    list<Trait> traits = [SameTypeOperands, Elementwise]>
    : LHLO_Op<mnemonic, traits> {
  let arguments = (ins
    Arg<BufferType, "", [MemRead]>:$input,
    Arg<BufferType, "", [MemWrite]>:$output
  );
}
// Abs supports complex to real, so element type is not guaranteed to match.
// The trait list is therefore replaced with just SameOperandsShape instead of
// the base class default; a custom verifier is declared.
def LHLO_AbsOp
    : LHLO_UnaryElementwiseOp<"abs", LHLO_Buffer, [SameOperandsShape]> {
  let summary = "Absolute value operator";
  let description = [{
Returns `abs(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
  let hasVerifier = 1;
}
// TODO(timshen): add a custom verifier.
// Only SameOperandsShape is requested, so $input and $output are not required
// to share an element type.
def LHLO_BitcastConvertOp
    : LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer,
                              [SameOperandsShape]> {
  let summary = "BitcastConvert operator";
  let description = [{
Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast
operation from a data shape to a target shape. The dimensions must match,
and the conversion is an element-wise one. Bitcast is implemented as a
low-level cast, so machines with different floating-point representations
will give different results.
See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype.
}];
}
// Unary element-wise op restricted to LHLO_FpBuffer; base-class default traits.
def LHLO_CbrtOp
    : LHLO_UnaryElementwiseOp<"cbrt", LHLO_FpBuffer> {
  let summary = "Cubic root operator";
  let description = [{
Returns element-wise cubic root of the operand.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpBuffer; base-class default traits.
def LHLO_CeilOp
    : LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer> {
  let summary = "Ceil operator";
  let description = [{
Returns `Ceil(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_IntBuffer; base-class default
// traits. The mnemonic is spelled out in full ("count_leading_zeros").
def LHLO_ClzOp
    : LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer> {
  let summary = "Count-leading-zeros (Clz) operator";
  let description = [{
Returns the number of leading zeros in each operand element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// TODO(timshen): add a custom verifier.
// Only SameOperandsShape is requested, so $input and $output may differ in
// element type.
def LHLO_ConvertOp
    : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]> {
  let summary = "Convert operator";
  let description = [{
Performs element-wise conversion of values from one type to another, e.g.
float to int.
See https://www.tensorflow.org/xla/operation_semantics#convertelementtype.
}];
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_CosineOp
    : LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer> {
  let summary = "Cos operator";
  let description = [{
Returns `Cos(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_TanOp
    : LHLO_UnaryElementwiseOp<"tan", LHLO_FpOrComplexBuffer> {
  let summary = "Tan operator";
  let description = [{
Returns `Tan(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_ExpOp
    : LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer> {
  let summary = "Exponential operator";
  let description = [{
Returns `e^(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_Expm1Op
    : LHLO_UnaryElementwiseOp<"exponential_minus_one",
                              LHLO_FpOrComplexBuffer> {
  let summary = "Exponential minus one operator";
  let description = [{
Returns `e^(operand) - 1` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpBuffer; base-class default traits.
def LHLO_FloorOp
    : LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer> {
  let summary = "Floor operator";
  let description = [{
Returns `Floor(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Derives from LHLO_Op directly because $input (complex buffer) and $output
// (floating-point buffer) use different buffer constraints; only the shapes
// are tied together via SameOperandsShape.
def LHLO_ImagOp
    : LHLO_Op<"imag", [SameOperandsShape]> {
  let summary = "Imag operator";
  let description = [{
Returns `Imag(operand)` element-wise.
}];
  let arguments = (ins
    Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
    Arg<LHLO_FpBuffer, "", [MemWrite]>:$output
  );
}
// Derives from LHLO_Op directly: $input is a floating-point buffer while
// $output is a predicate buffer. Shapes are tied via SameOperandsShape.
def LHLO_IsFiniteOp
    : LHLO_Op<"is_finite", [SameOperandsShape]> {
  let summary = "IsFinite operator";
  let description = [{
Tests whether each element of operand is finite, i.e., is not positive or
negative infinity, and is not NaN. Returns a tensor of 1-bit integers with
the same shape as the input, where each element is nonzero (i.e. true) if
and only if the corresponding input element is finite.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
  let arguments = (ins
    Arg<LHLO_FpBuffer, "", [MemRead]>:$input,
    Arg<LHLO_PredBuffer, "", [MemWrite]>:$output
  );
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_LogOp
    : LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer> {
  let summary = "Logarithm operator";
  let description = [{
Returns `log(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_LogisticOp
    : LHLO_UnaryElementwiseOp<"logistic", LHLO_FpOrComplexBuffer> {
  let summary = "Logistic operator";
  let description = [{
Returns `logistic(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_Log1pOp
    : LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer> {
  let summary = "Log1p operator";
  let description = [{
Returns `log(operand+1)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op using both base-class defaults: LHLO_Buffer operands
// and the default trait list.
def LHLO_NegOp
    : LHLO_UnaryElementwiseOp<"negate"> {
  let summary = "Negation operator";
  let description = [{
Returns `-operand` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_PredOrIntBuffer; base-class default
// traits.
def LHLO_NotOp
    : LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer> {
  let summary = "Not operator";
  let description = [{
Returns `!operand` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_IntBuffer; base-class default
// traits. Note the abbreviated mnemonic "popcnt".
def LHLO_PopulationCountOp
    : LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer> {
  let summary = "PopulationCount operator";
  let description = [{
Returns the number of bits set in each operand element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Derives from LHLO_Op directly because $input (complex buffer) and $output
// (floating-point buffer) use different buffer constraints; only the shapes
// are tied together via SameOperandsShape.
def LHLO_RealOp
    : LHLO_Op<"real", [SameOperandsShape]> {
  let summary = "Real operator";
  let description = [{
Returns `Real(operand)` element-wise.
}];
  let arguments = (ins
    Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
    Arg<LHLO_FpBuffer, "", [MemWrite]>:$output
  );
}
// Unary element-wise op restricted to LHLO_FpBuffer; base-class default
// traits. The mnemonic suffix "afz" matches the "away from zero" tie-breaking
// rule stated in the description.
def LHLO_RoundOp
    : LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer> {
  let summary = "Round operator";
  let description = [{
Returns `Round(operand)` element-wise, rounding to nearest integer with
half-way cases rounding away from zero.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpBuffer; base-class default traits.
def LHLO_RoundNearestEvenOp
    : LHLO_UnaryElementwiseOp<"round_nearest_even", LHLO_FpBuffer> {
  let summary = "Round nearest even operator";
  let description = [{
Returns `Round(operand)` element-wise, rounding to nearest integer with
half-way cases rounding towards even numbers.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_RsqrtOp
    : LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer> {
  let summary = "Reciprocal Square-root operator";
  let description = [{
Returns `1.0 / sqrt(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_SqrtOp
    : LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer> {
  let summary = "Square-root operator";
  let description = [{
Returns `sqrt(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op using both base-class defaults: LHLO_Buffer operands
// and the default trait list.
def LHLO_SignOp
    : LHLO_UnaryElementwiseOp<"sign"> {
  let summary = "Sign operator";
  let description = [{
Returns `sign(operand)` element-wise, where
```
sign(x) = -1 : x < 0
= -0 : x = -0
= NaN : x = NaN
= +0 : x = +0
= 1 : x > 0
```
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_SineOp
    : LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer> {
  let summary = "Sin operator";
  let description = [{
Returns `Sin(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
// Unary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits.
def LHLO_TanhOp
    : LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer> {
  let summary = "Tanh operator";
  let description = [{
Returns `tanh(operand)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
//===----------------------------------------------------------------------===//
// LMHLO binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
//
// Base class for two-input element-wise ops. `BufferType` constrains the
// read-only $lhs and $rhs as well as the write-only $out. An optional
// $broadcast_dimensions attribute is accepted. Unless overridden, the traits
// are SameTypeOperands and Elementwise.
class LHLO_BinaryElementwiseOp<
    string mnemonic,
    Type BufferType = LHLO_Buffer,
    list<Trait> traits = [SameTypeOperands, Elementwise]>
    : LHLO_Op<mnemonic, traits> {
  let arguments = (ins
    Arg<BufferType, "", [MemRead]>:$lhs,
    Arg<BufferType, "", [MemRead]>:$rhs,
    Arg<BufferType, "", [MemWrite]>:$out,
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}
// Binary element-wise op using both base-class defaults: LHLO_Buffer operands
// and the default trait list.
def LHLO_AddOp
    : LHLO_BinaryElementwiseOp<"add"> {
  let summary = "Addition operator";
  let description = [{
Returns `lhs + rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op restricted to LHLO_PredOrIntBuffer; base-class
// default traits.
def LHLO_AndOp
    : LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer> {
  let summary = "Logical and";
  let description = [{
Returns `logical_and(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op restricted to LHLO_FpOrComplexBuffer; base-class
// default traits. The description spells the call with two arguments,
// matching the two read operands ($lhs, $rhs) this op takes.
def LHLO_Atan2Op
    : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer> {
  let summary = "Atan2 operator";
  let description = [{
Returns `atan2(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Derives from LHLO_Op directly: the two read operands are floating-point
// buffers while the written buffer is complex. Shapes are tied via
// SameOperandsShape. The result operand is named $output here.
def LHLO_ComplexOp
    : LHLO_Op<"complex", [SameOperandsShape]> {
  let summary = "Complex operator";
  let description = [{
Performs element-wise conversion of a pair of real and imaginary values to
a complex value.
}];
  let arguments = (ins
    Arg<LHLO_FpBuffer, "", [MemRead]>:$lhs,
    Arg<LHLO_FpBuffer, "", [MemRead]>:$rhs,
    Arg<LHLO_ComplexBuffer, "", [MemWrite]>:$output,
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}
// Binary element-wise op using both base-class defaults: LHLO_Buffer operands
// and the default trait list.
def LHLO_DivOp
    : LHLO_BinaryElementwiseOp<"divide"> {
  let summary = "Division operator";
  let description = [{
Returns `lhs / rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op using both base-class defaults: LHLO_Buffer operands
// and the default trait list.
def LHLO_MaxOp
    : LHLO_BinaryElementwiseOp<"maximum"> {
  let summary = "Maximum operator";
  let description = [{
Returns `max(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op using both base-class defaults: LHLO_Buffer operands
// and the default trait list.
def LHLO_MinOp
    : LHLO_BinaryElementwiseOp<"minimum"> {
  let summary = "Minimum operator";
  let description = [{
Returns `min(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op using both base-class defaults: LHLO_Buffer operands
// and the default trait list.
def LHLO_MulOp
    : LHLO_BinaryElementwiseOp<"multiply"> {
  let summary = "Multiplication operator";
  let description = [{
Returns `lhs * rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op restricted to LHLO_PredOrIntBuffer; base-class
// default traits.
def LHLO_OrOp
    : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer> {
  let summary = "Logical or";
  let description = [{
Returns `logical_or(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op using both base-class defaults: LHLO_Buffer operands
// and the default trait list.
def LHLO_PowOp
    : LHLO_BinaryElementwiseOp<"power"> {
  let summary = "Power operator";
  let description = [{
Returns `lhs ^ rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op restricted to LHLO_IntOrFpBuffer; base-class default
// traits.
def LHLO_RemOp
    : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer> {
  let summary = "Remainder operator";
  let description = [{
Returns `lhs % rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op restricted to LHLO_IntBuffer; base-class default
// traits.
def LHLO_ShiftLeftOp
    : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer> {
  let summary = "Shift Left operator";
  let description = [{
Returns `lhs << rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op restricted to LHLO_IntBuffer; base-class default
// traits.
def LHLO_ShiftRightArithmeticOp
    : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer> {
  let summary = "Shift right arithmetic operator";
  let description = [{
Returns arithmetic `lhs >> rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op restricted to LHLO_IntBuffer; base-class default
// traits.
def LHLO_ShiftRightLogicalOp
    : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer> {
  let summary = "Shift right logical operator";
  let description = [{
Returns logical `lhs >> rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op using both base-class defaults: LHLO_Buffer operands
// and the default trait list.
def LHLO_SubtractOp
    : LHLO_BinaryElementwiseOp<"subtract"> {
  let summary = "Subtraction operator";
  let description = [{
Returns `lhs - rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
// Binary element-wise op restricted to LHLO_PredOrIntBuffer; base-class
// default traits.
def LHLO_XorOp
    : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer> {
  let summary = "Logical xor";
  let description = [{
Returns `logical_xor(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
//===----------------------------------------------------------------------===//
// LMHLO control flow op definitions.
//===----------------------------------------------------------------------===//
// TODO(b/139813999): specify required function signature in a type-safe way.
//
// The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are
// moving towards mhlo.ReturnOp, but some code that needs cleanup still assumes
// lmhlo.TerminatorOp.
// TODO(timshen): cleanup lmhlo.TerminatorOp.
//
// Three variadic operand groups ($inputs, $init_values, $out) that must all
// have the same length (SameVariadicOperandSize), a $dimensions attribute and
// a single-block `body` region. A canonicalizer is declared.
def LHLO_ReduceOp
    : LHLO_Op<"reduce", [SameVariadicOperandSize]> {
  let summary = "Reduce operator";
  let description = [{
Returns the result of executing a reduction function on one or more arrays
in parallel.
See https://www.tensorflow.org/xla/operation_semantics#reduce.
}];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$inputs,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out,
    I64ElementsAttr:$dimensions
  );
  let regions = (region SizedRegion<1>:$body);
  let hasCanonicalizer = 1;
}
// Windowed reduction. Same three equally sized variadic operand groups as a
// plain reduce, plus window attributes; only $window_dimensions is mandatory.
// A custom verifier is declared.
def LHLO_ReduceWindowOp
    : LHLO_Op<"reduce_window", [SameVariadicOperandSize]> {
  let summary = "ReduceWindow operator";
  let description = [{
Returns the result of executing a reduction function over all elements in
each window of one or more arrays in parallel.
See https://www.tensorflow.org/xla/operation_semantics#reducewindow.
}];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$inputs,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out,
    I64ElementsAttr:$window_dimensions,
    // If strides or dilations attributes are missing then the default value is
    // one for each of the input dimensions. Similarly, padding values are zero
    // for both low and high in each of the dimensions, if not specified.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$base_dilations,
    OptionalAttr<I64ElementsAttr>:$window_dilations,
    OptionalAttr<I64ElementsAttr>:$padding
  );
  let regions = (region SizedRegion<1>:$body);
  let hasVerifier = 1;
}
// TODO(timshen): Add a custom syntax for this.
//
// Switch-case over a variadic list of single-block regions, selected by the
// read-only $index buffer. Each region is implicitly terminated by
// TerminatorOp, and the RegionBranchOpInterface methods are declared for
// implementation in C++.
// NOTE(review): the description mentions `branch_operands`, but the only
// operand declared here is $index -- confirm whether that paragraph still
// applies to the buffer form of the op.
def LHLO_CaseOp
    : LHLO_Op<"case", [
        SingleBlockImplicitTerminator<"TerminatorOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
  let summary = "Switch-Case operator";
  let description = [{
Returns the result of executing `branches[index]`. If
`index` is < 0 or >= N, then `branches[N-1]` is executed as
the default branch.
Each branch `branches[b]` must take in a single argument of same type as
`branch_operands[b]` and will be invoked with `branch_operands[b]`. The type
of the returned value of each branch must be the same.
Note that only one of the branches will be executed depending on the value
of index.
See https://www.tensorflow.org/xla/operation_semantics#conditional.
}];
  let arguments = (ins Arg<LHLO_PredOrIntBuffer, "", [MemRead]>:$index);
  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}
// TODO(timshen): Add a custom syntax for this.
//
// While loop with two single-block regions, `cond` and `body`. $cond_val is a
// variadic list of predicate buffers declared with a MemWrite effect, and
// $trip_count is an optional attribute. The RegionBranchOpInterface and
// LoopLikeOpInterface methods are declared for implementation in C++.
def LHLO_WhileOp
    : LHLO_Op<"while", [
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
        DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
  let summary = "While operator";
  let description = [{
Returns the result of executing a body function until the cond body returns
false.
See https://www.tensorflow.org/xla/operation_semantics#while.
}];
  let arguments = (ins
    Arg<Variadic<LHLO_PredBuffer>, "", [MemWrite]>:$cond_val,
    OptionalAttr<I64Attr>:$trip_count
  );
  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}
// Custom call with two variadic operand groups ($args read, $output written);
// AttrSizedOperandSegments records how the operands split between them. The
// op also owns an unconstrained `called_computation` region and declares a
// custom verifier.
def LHLO_CustomCallOp
    : LHLO_Op<"custom_call", [AttrSizedOperandSegments]> {
  let summary = "CustomCall operator";
  let description = [{
A custom call invokes code external to XLA. The `args` are passed to the
external code, and the external code is expected to produce a result of the
given type. The exact mechanism is backend-specific. For example, in the CPU
backend, a call instruction is emitted which targets a symbol with the name
`call_target_name`.
`call_target_name` and `backend_config` can be arbitrary strings, but
`call_target_name` should be short as it may be used in labels.
`backend_config` can encode arbitrarily large amounts of information.
See https://www.tensorflow.org/xla/operation_semantics#customcall.
}];
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$args,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
    StrAttr:$call_target_name,
    DefaultValuedOptionalAttr<BoolAttr, "false">:$has_side_effect,
    OptionalAttr<AnyAttrOf<[StrAttr, DictionaryAttr]>>:$backend_config,
    // TODO(b/189822916): Remove this field when all clients are migrated to
    // the status-returning API.
    DefaultValuedOptionalAttr<
        MHLO_CustomCallApiVersionAttr,
        "mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL">:$api_version,
    OptionalAttr<CustomCallTargetArgMappingAttr>:$target_arg_mapping
  );
  let hasVerifier = 1;
  let regions = (region AnyRegion:$called_computation);
}
//===----------------------------------------------------------------------===//
// LMHLO comparison op definitions.
//===----------------------------------------------------------------------===//
// Element-wise comparison: reads $lhs and $rhs, writes a predicate buffer
// $out. $comparison_direction is mandatory; $compare_type and
// $broadcast_dimensions are optional. No extra traits are requested.
def LHLO_CompareOp
    : LHLO_Op<"compare", []> {
  let summary = "Comparison operator";
  let description = [{
Compares `lhs` and `rhs` elementwise according to `comparison_direction`
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
types, SIGNED for signed element types and UNSIGNED for unsigned element
types.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
}];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
    Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
    MHLO_ComparisonDirectionAttr:$comparison_direction,
    OptionalAttr<MHLO_ComparisonTypeAttr>:$compare_type
  );
}
//===----------------------------------------------------------------------===//
// LMHLO Slice definitions.
//===----------------------------------------------------------------------===//
// Static slice: reads $operand and writes $output. The three index attributes
// ($start_indices, $limit_indices, $strides) are required to have matching
// types via AllTypesMatch.
def LHLO_SliceOp
    : LHLO_Op<
        "slice",
        [AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
  let summary = "Slice operator";
  let description = [{
Extracts a sub-array from the input array. The bounding box of the sub-array
is given by `start_indices` and `limit_indices`, and `strides` selects the
step taken along each dimension.
See https://www.tensorflow.org/xla/operation_semantics#slice.
}];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$start_indices,
    I64ElementsAttr:$limit_indices,
    I64ElementsAttr:$strides
  );
}
// Dynamic slice: the start indices arrive as a variadic list of read-only
// buffers, the slice extent as the $slice_sizes attribute. $operand and
// $output must agree on element type (AllElementTypesMatch).
def LHLO_DynamicSliceOp
    : LHLO_Op<"dynamic_slice",
              [AllElementTypesMatch<["operand", "output"]>]> {
  let summary = "Dynamic Slice operator";
  let description = [{
Extracts a sub-array from the input array at dynamic start_indices.
See https://www.tensorflow.org/xla/operation_semantics#dynamicslice.
}];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$slice_sizes
  );
}
// Dynamic update slice: reads $operand, $update and the variadic
// $start_indices buffers, and writes $output.
// NOTE(review): the mnemonic is hyphenated ("dynamic-update-slice") whereas
// the other mnemonics in this file use underscores. It is kept as is because
// renaming it would change the op's textual name.
def LHLO_DynamicUpdateSliceOp
    : LHLO_Op<"dynamic-update-slice", []> {
  let summary = "Dynamic Update Slice operator";
  let description = [{
DynamicUpdateSlice generates a result which is the value of the input array
operand, with a slice update overwritten at start_indices.
See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice.
}];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$update,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
//===----------------------------------------------------------------------===//
// LMHLO Other op definitions.
//===----------------------------------------------------------------------===//
// Batch-norm gradient: five read-only buffers followed by three write-only
// gradient buffers, plus the $epsilon and $feature_index attributes.
def LHLO_BatchNormGradOp
    : LHLO_Op<"batch_norm_grad", []> {
  let summary = "Batch Normalization Gradient";
  let description = [{
Calculates gradients of batch norm.
See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad
}];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
    Arg<LHLO_Buffer, "", [MemRead]>:$variance,
    Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand,  // gradient of $operand.
    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
}
// Batch-norm inference: five read-only buffers and a single write-only
// $output, plus the $epsilon and $feature_index attributes.
def LHLO_BatchNormInferenceOp
    : LHLO_Op<"batch_norm_inference", []> {
  let summary = "Batch Normalization for Inference";
  let description = [{
Normalizes an array across batch and spatial dimensions.
See https://www.tensorflow.org/xla/operation_semantics#batchnorminference
}];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
    Arg<LHLO_Buffer, "", [MemRead]>:$variance,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
}
// Batch-norm training: three read-only buffers and three write-only buffers
// ($output, $batch_mean, $batch_var), plus the $epsilon and $feature_index
// attributes.
def LHLO_BatchNormTrainingOp
    : LHLO_Op<"batch_norm_training", []> {
  let summary = "Batch Normalization for Training";
  let description = [{
Normalizes an array across batch and spatial dimensions.
See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining
}];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
    Arg<LHLO_Buffer, "", [MemWrite]>:$batch_var,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
}
// Prepending broadcast: reads $operand, writes $output; the sizes of the new
// leading dimensions come from the $broadcast_sizes attribute.
def LHLO_BroadcastOp
    : LHLO_Op<"broadcast", []> {
  let summary = "Broadcast a tensor to a higher rank by prepending dimensions";
  let description = [{
Broadcasts the operand tensor to a higher rank by prepending
`broadcast_sizes` to the dimensions. The current values of the operand are
copied into the other dimensions.
This is a more limited form of broadcasting, that corresponds to the XLA
client Broadcast method. For a more general form of broadcasting, see the
BroadcastInDimOp.
See https://www.tensorflow.org/xla/operation_semantics#broadcast.
}];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$broadcast_sizes
  );
}
// General broadcast: reads $operand, writes $output; the mandatory
// $broadcast_dimensions attribute maps operand dimensions onto dimensions of
// the target shape, as the description explains.
def LHLO_BroadcastInDimOp
    : LHLO_Op<"broadcast_in_dim", []> {
  let summary = "Broadcast a tensor into the given shape by adding dimensions.";
  let description = [{
Broadcasts the `operand` tensor to a higher rank. This is not the limited
form of broadcasting exposed as the XLA client broadcast op, but rather the
more powerful "InDim" broadcasting, which is closer to the HLO broadcast op
and exposed in the XLA client BroadcastInDim method.
`broadcast_dimensions` maps the operand dimension number to the target shape
dimension number. It must have the same size as the rank of the operand. The
mapped dimensions must either be the same size or the dimension being
broadcast from must be size 1 (degenerate broadcasting).
For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The
scalar value will be broadcast to every element in the target shape.
See https://www.tensorflow.org/xla/broadcasting.
}];
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    BroadcastDimAttr:$broadcast_dimensions
  );
}
// Op with mnemonic "clamp" and an empty trait list.
// Takes three buffers that are read (`min`, `operand`, `max`, in that operand
// order) and one buffer that is written (`output`). No `results` are declared
// in this definition.
def LHLO_ClampOp : LHLO_Op<"clamp", []> {
let summary = "Clamp operator";
let description = [{
Clamps an operand to within the range between a minimum and maximum value.
Note: All three arrays must be the same shape. Alternatively, as a
restricted form of broadcasting, min and/or max can be a scalar (0D
tensor) of the element type of the tensor operand.
See https://www.tensorflow.org/xla/operation_semantics#clamp.
}];
let arguments = (ins
// Lower bound (read only).
Arg<LHLO_Buffer, "", [MemRead]>:$min,
// Values to clamp (read only).
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
// Upper bound (read only).
Arg<LHLO_Buffer, "", [MemRead]>:$max,
// Destination buffer (written).
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Op with mnemonic "concatenate" and an empty trait list.
// Accepts a variadic list of input buffers (`val`, read), followed by a single
// `output` buffer (written) and an `I64Attr` named `dimension`. No `results`
// are declared in this definition.
def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []> {
let summary = "XLA's concatenate op";
let description = [{
Concatenates a set of tensors along the specified dimension.
See https://www.tensorflow.org/xla/operation_semantics#concatenate.
}];
let arguments = (ins
// Variadic group of input buffers (read only).
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
// Destination buffer (written).
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
// Dimension along which the inputs are concatenated (per the description).
I64Attr:$dimension
);
}
// Op with mnemonic "convolution" and an empty trait list.
// The argument dag is built with `!con`: three buffer operands (`lhs` and `rhs`
// read, `output` written) followed by the shared attribute list taken from
// `MHLO_ConvolutionAttributes.attributes`. No `results` are declared.
def LHLO_ConvolutionOp : LHLO_Op<"convolution", []> {
  let summary = "Convolution operator";
  let description = [{
    Computes a convolution of the kind used in neural networks.

    See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
  }];
  let arguments = !con(
    (ins
      Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
      Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
      Arg<LHLO_Buffer, "", [MemWrite]>:$output),
    MHLO_ConvolutionAttributes.attributes);

  // Extra C++ injected into the generated op class. Uses `let` (override of
  // the inherited field) to match the other ops in this file, e.g. the copy
  // op, instead of re-declaring the field with `code`.
  //
  // `hasWindowReversal()` returns true iff the window-reversal attribute is
  // present and at least one of its boolean values is true.
  let extraClassDeclaration = [{
    bool hasWindowReversal() {
      auto reversal = getWindowReversalAttr();
      return reversal && llvm::any_of(reversal.getValues<bool>(),
                                      [](bool v) { return v; });
    }
  }];

  // Custom textual form: parenthesized operand list, then `dim_numbers = ...`
  // and `window = { ... }` handled by custom directives, then the attribute
  // dictionary and a functional type.
  let assemblyFormat = [{
    `(`operands`)`
    `dim_numbers` `=` custom<ConvolutionDimensions>($dimension_numbers) `,`
    `window` `=` `{` custom<WindowAttributes>($window_strides, $padding,
                                                $lhs_dilation, $rhs_dilation,
                                                $window_reversal) `}`
    attr-dict `:` functional-type(operands, results)
  }];
}
// Op with mnemonic "copy", carrying the `CopyOpInterface` trait.
// Reads the `operand` buffer and writes the `output` buffer; no `results` are
// declared in this definition.
def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]> {
let summary = "Copy operator";
let description = [{
Returns a copy of `operand`.
}];
let arguments = (ins
// Source buffer (read only).
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
// Destination buffer (written).
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
// Extra C++ members for the generated op class: `getSource()` forwards to
// `getOperand()` and `getTarget()` forwards to `getOutput()`.
// NOTE(review): presumably these are the accessors `CopyOpInterface` expects;
// confirm against the interface definition.
let extraClassDeclaration = [{
Value getSource() { return getOperand();}
Value getTarget() { return getOutput(); }
}];
}
// Op with mnemonic "dot" and an empty trait list.
// Reads the `lhs` and `rhs` buffers and writes the `output` buffer. Note that
// the two attributes are listed between the inputs and `output` in the `ins`
// dag; that ordering is part of the op's interface and must be kept.
// No `results` are declared in this definition.
def LHLO_DotOp: LHLO_Op<"dot", []> {
let summary = "Dot operator";
let description = [{
Performs dot products between vectors, vector/matrix and matrix/matrix
multiplication.
See https://www.tensorflow.org/xla/operation_semantics#dot.
}];
let arguments = (ins
// Left-hand side input buffer (read only).
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
// Right-hand side input buffer (read only).
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
// Attributes configuring the dot: dimension numbers and precision config.
MHLO_DotDimensionNumbers:$dot_dimension_numbers,
MHLO_PrecisionConfigAttr:$precision_config,
// Destination buffer (written).
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Op with mnemonic "gather" and an empty trait list.
// Reads the `operand` buffer and the integer `start_indices` buffer, and
// writes the `output` buffer. The two attributes sit between the inputs and
// `output` in the `ins` dag; that ordering is part of the op's interface and
// must be kept. No `results` are declared in this definition.
def LHLO_GatherOp: LHLO_Op<"gather", []> {
  let summary = "Gather operator";
  let description = [{
    Stitches together several slices of the input array `operand`, each slice
    taken at a (potentially different) runtime offset given by `start_indices`,
    and writes the result into `output`.

    See https://www.tensorflow.org/xla/operation_semantics#gather.
  }];
  let arguments = (ins
    // Array being gathered from (read only).
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    // Integer buffer holding the slice starting indices (read only).
    Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
    // Attributes describing how indices map to slices and the slice extents.
    MHLO_GatherDimensionNumbers:$dimension_numbers,
    I64ElementsAttr:$slice_sizes,
    // Destination buffer (written).
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
// Op with mnemonic "reshape" and an empty trait list.
// Reads the `operand` buffer and writes the `output` buffer; it takes no
// attributes and declares no `results` in this definition.
def LHLO_ReshapeOp: LHLO_Op<"reshape", []> {
let summary = "Reshape operator";
let description = [{
Reshapes the dimensions of `operand` into a new configuration.
See https://www.tensorflow.org/xla/operation_semantics#reshape.
}];
let arguments = (ins
// Source buffer (read only).
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
// Destination buffer (written).
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
// Op with mnemonic "scatter" and an empty trait list.
// Reads three buffers (`operand`, `scatter_indices`, `updates`), writes the
// `output` buffer, and carries one region named `update_computation`.
// No `results` are declared in this definition.
def LHLO_ScatterOp: LHLO_Op<"scatter", []> {
let summary = "Scatter operator";
let description = [{
Generates a result which is the value of the input array `operand`,
with several slices (at indices specified by `scatter_indices`)
updated with the values in `updates` using `update_computation`.
See https://www.tensorflow.org/xla/operation_semantics#scatter.
}];
let arguments = (ins
// Input array (read only).
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
// Indices of the slices to update (read only).
// NOTE(review): typed as a generic LHLO_Buffer, whereas the gather op in this
// file constrains its index operand to LHLO_IntBuffer; confirm whether the
// looser constraint here is intentional.
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
// Values combined into the selected slices (read only).
Arg<LHLO_Buffer, "", [MemRead]>:$updates,
// Destination buffer (written).
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
MHLO_ScatterDimensionNumbers:$scatter_dimension_numbers,
// Both boolean attributes are optional and default to "false".
DefaultValuedOptionalAttr<BoolAttr, "false">:$indices_are_sorted,
DefaultValuedOptionalAttr<BoolAttr, "false">:$unique_indices
);
// Region constrained by `SizedRegion<1>`; per the description it is the
// computation used to combine existing values with `updates`.
let regions = (region SizedRegion<1>:$update_computation);
}
// Op with mnemonic "select", carrying the `Elementwise` trait.
// Reads a predicate buffer (`pred`, constrained by LHLO_PredBuffer) and the two
// candidate buffers `on_true` / `on_false`, and writes the `output` buffer.
// No `results` are declared in this definition.
def LHLO_SelectOp: LHLO_Op<"select", [Elementwise]> {
let summary = "Select operator";
let description = [{
Constructs an output tensor from the elements of `on_true` and `on_false`
based on the values of `pred`.
`pred`, `on_true` and `on_false` must be broadcast compatible.
}];
let arguments = (ins
// Predicate buffer choosing between the two inputs (read only).
Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
// Candidate values, selected according to `pred` (read only).
Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
// Destination buffer (written).
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}