-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
simd.mojo
2915 lines (2359 loc) · 89.6 KB
/
simd.mojo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2024, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
"""Implements SIMD struct.
These are Mojo built-ins, so you don't need to import them.
"""
from bit import pop_count
from sys import (
llvm_intrinsic,
has_neon,
is_x86,
triple_is_nvidia_cuda,
simdwidthof,
_RegisterPackType,
)
from builtin._math import Ceilable, CeilDivable, Floorable, Truncable
from builtin.hash import _hash_simd
from memory import bitcast
from utils.numerics import (
FPUtils,
isnan as _isnan,
nan as _nan,
max_finite as _max_finite,
min_finite as _min_finite,
max_or_inf as _max_or_inf,
min_or_neg_inf as _min_or_neg_inf,
)
from utils._visualizers import lldb_formatter_wrapping_type
from utils import InlineArray, StringSlice
from .dtype import (
_integral_type_of,
_get_dtype_printf_format,
_scientific_notation_digits,
)
from .io import _snprintf_scalar, _printf, _print_fmt
from .string import _calc_initial_buffer_size, _calc_format_buffer_size
# ===----------------------------------------------------------------------=== #
# Type Aliases
# ===----------------------------------------------------------------------=== #
alias Scalar = SIMD[size=1]
"""Represents a scalar dtype, i.e. a `SIMD` vector with exactly one element."""
alias Int8 = Scalar[DType.int8]
"""Represents an 8-bit signed scalar integer."""
alias UInt8 = Scalar[DType.uint8]
"""Represents an 8-bit unsigned scalar integer."""
alias Int16 = Scalar[DType.int16]
"""Represents a 16-bit signed scalar integer."""
alias UInt16 = Scalar[DType.uint16]
"""Represents a 16-bit unsigned scalar integer."""
alias Int32 = Scalar[DType.int32]
"""Represents a 32-bit signed scalar integer."""
alias UInt32 = Scalar[DType.uint32]
"""Represents a 32-bit unsigned scalar integer."""
alias Int64 = Scalar[DType.int64]
"""Represents a 64-bit signed scalar integer."""
alias UInt64 = Scalar[DType.uint64]
"""Represents a 64-bit unsigned scalar integer."""
alias BFloat16 = Scalar[DType.bfloat16]
"""Represents a 16-bit brain floating point value."""
alias Float16 = Scalar[DType.float16]
"""Represents a 16-bit floating point value."""
alias Float32 = Scalar[DType.float32]
"""Represents a 32-bit floating point value."""
alias Float64 = Scalar[DType.float64]
"""Represents a 64-bit floating point value."""
# ===----------------------------------------------------------------------=== #
# Utilities
# ===----------------------------------------------------------------------=== #
@always_inline("nodebug")
fn _simd_construction_checks[type: DType, size: Int]():
    """Validates the parameters of a SIMD vector at compile time.

    The checks require that the element type is not `DType.invalid`, that the
    width is positive and a power of two, and that `DType.bfloat16` is not
    used when `has_neon()` is true.

    Parameters:
        type: The data type of SIMD vector elements.
        size: The number of elements in the SIMD vector.
    """
    constrained[type != DType.invalid, "simd type cannot be DType.invalid"]()
    constrained[size > 0, "simd width must be > 0"]()
    # A power of two has exactly one bit set, so clearing its lowest set bit
    # must leave zero.
    constrained[(size & (size - 1)) == 0, "simd width must be power of 2"]()
    constrained[
        not (type == DType.bfloat16 and has_neon()),
        "bf16 is not supported for ARM architectures",
    ]()
@always_inline("nodebug")
fn _unchecked_zero[type: DType, size: Int]() -> SIMD[type, size]:
    """Builds an all-zeros SIMD vector without running construction checks.

    Unlike the `SIMD` initializers, this helper does not call
    `_simd_construction_checks`; it fills the `value` field directly.

    Parameters:
        type: The data type of SIMD vector elements.
        size: The number of elements in the SIMD vector.

    Returns:
        A SIMD vector whose elements are all zero.
    """
    # Materialize the constant 0 as an index scalar.
    var index_zero = __mlir_op.`kgen.param.constant`[
        _type = __mlir_type[`!pop.scalar<index>`],
        value = __mlir_attr[`#pop.simd<0> : !pop.scalar<index>`],
    ]()
    # Convert it to the requested element type.
    var scalar_zero = __mlir_op.`pop.cast`[
        _type = __mlir_type[`!pop.scalar<`, type.value, `>`]
    ](index_zero)
    # Broadcast the scalar over every lane.
    var splatted = __mlir_op.`pop.simd.splat`[
        _type = __mlir_type[`!pop.simd<`, size.value, `, `, type.value, `>`]
    ](scalar_zero)
    return SIMD[type, size] {value: splatted}
# ===----------------------------------------------------------------------=== #
# SIMD
# ===----------------------------------------------------------------------=== #
@lldb_formatter_wrapping_type
@register_passable("trivial")
struct SIMD[type: DType, size: Int = simdwidthof[type]()](
Absable,
Boolable,
Ceilable,
CeilDivable,
CollectionElement,
CollectionElementNew,
Floorable,
Hashable,
Intable,
Powable,
Roundable,
Sized,
Stringable,
Truncable,
Representable,
):
"""Represents a small vector that is backed by a hardware vector element.
SIMD allows a single instruction to be executed across the multiple data
elements of the vector.
Constraints:
The size of the SIMD vector must be positive and a power of 2.
Parameters:
type: The data type of SIMD vector elements.
size: The size of the SIMD vector.
"""
alias _Mask = SIMD[DType.bool, size]
alias element_type = type
var value: __mlir_type[`!pop.simd<`, size.value, `, `, type.value, `>`]
"""The underlying storage for the vector."""
alias MAX = Self(_max_or_inf[type]())
"""Gets the maximum value for the SIMD value, potentially +inf."""
alias MIN = Self(_min_or_neg_inf[type]())
"""Gets the minimum value for the SIMD value, potentially -inf."""
alias MAX_FINITE = Self(_max_finite[type]())
"""Returns the maximum finite value of SIMD value."""
alias MIN_FINITE = Self(_min_finite[type]())
"""Returns the minimum (lowest) finite value of SIMD value."""
@always_inline("nodebug")
fn __init__(inout self):
    """Default initializer of the SIMD vector.

    By default the SIMD vectors are initialized to all zeros. The `type` and
    `size` parameters are validated with `_simd_construction_checks` first.
    """
    _simd_construction_checks[type, size]()
    self = _unchecked_zero[type, size]()
@always_inline("nodebug")
fn __init__(inout self, value: SIMD[DType.float64, 1]):
    """Initializes the SIMD vector from a 64-bit float scalar.

    The scalar is cast to the element type and then splatted across all the
    elements of the SIMD vector.

    Args:
        value: The input value.
    """
    _simd_construction_checks[type, size]()

    # Convert the float64 scalar to the element type of this vector.
    var scalar = __mlir_op.`pop.cast`[
        _type = __mlir_type[`!pop.simd<1,`, type.value, `>`]
    ](value.value)

    # Broadcast the converted scalar to every lane.
    self.value = __mlir_op.`pop.simd.splat`[
        _type = __mlir_type[`!pop.simd<`, size.value, `, `, type.value, `>`]
    ](scalar)
@always_inline("nodebug")
fn __init__(inout self, *, other: SIMD[type, size]):
    """Explicitly copy the provided value.

    This keyword-only initializer forwards to `__copyinit__`.

    Args:
        other: The value to copy.
    """
    self.__copyinit__(other)
@always_inline("nodebug")
fn __init__(inout self, value: Int):
    """Initializes the SIMD vector from an `Int`.

    The integer is cast to the element type and then splatted across all the
    elements of the SIMD vector.

    Args:
        value: The input value.
    """
    _simd_construction_checks[type, size]()

    # Lift the builtin integer into an index scalar.
    var index_scalar = __mlir_op.`pop.cast_from_builtin`[
        _type = __mlir_type.`!pop.scalar<index>`
    ](value.value)

    # Convert the index scalar to the element type of this vector.
    var scalar = __mlir_op.`pop.cast`[
        _type = __mlir_type[`!pop.simd<1,`, type.value, `>`]
    ](index_scalar)

    # Broadcast the converted scalar to every lane.
    self.value = __mlir_op.`pop.simd.splat`[
        _type = __mlir_type[`!pop.simd<`, size.value, `, `, type.value, `>`]
    ](scalar)
@always_inline("nodebug")
fn __init__(inout self, value: IntLiteral):
    """Initializes the SIMD vector from an integer literal.

    The literal is converted through a 128-bit signed integer, cast to the
    element type, and then splatted across all the elements of the SIMD
    vector.

    Args:
        value: The input value.
    """
    _simd_construction_checks[type, size]()

    # Materialize the literal as a 128-bit signed integer.
    var literal_si128 = __mlir_op.`kgen.int_literal.convert`[
        _type = __mlir_type.si128
    ](value.value)

    # Lift the builtin integer into a scalar.
    var si128_scalar = __mlir_op.`pop.cast_from_builtin`[
        _type = __mlir_type.`!pop.scalar<si128>`
    ](literal_si128)

    # Convert the scalar to the element type of this vector.
    var scalar = __mlir_op.`pop.cast`[
        _type = __mlir_type[`!pop.simd<1,`, type.value, `>`]
    ](si128_scalar)

    # Broadcast the converted scalar to every lane.
    self.value = __mlir_op.`pop.simd.splat`[
        _type = __mlir_type[`!pop.simd<`, size.value, `, `, type.value, `>`]
    ](scalar)
@always_inline("nodebug")
fn __init__(inout self, value: Bool):
    """Initializes the SIMD vector from a `Bool`.

    The bool is cast to the element type and then splatted across all
    elements of the SIMD vector.

    Args:
        value: The bool value.
    """
    _simd_construction_checks[type, size]()

    # Convert the scalar bool to the element type of this vector.
    var scalar = __mlir_op.`pop.cast`[
        _type = __mlir_type[`!pop.simd<1,`, type.value, `>`]
    ](value._as_scalar_bool())

    # Broadcast the converted scalar to every lane.
    self.value = __mlir_op.`pop.simd.splat`[
        _type = __mlir_type[`!pop.simd<`, size.value, `, `, type.value, `>`]
    ](scalar)
@always_inline("nodebug")
fn __init__(
    inout self,
    value: __mlir_type[`!pop.simd<`, size.value, `, `, type.value, `>`],
):
    """Initializes the SIMD vector with the underlying mlir value.

    The value is stored as-is after the `type` and `size` parameters have
    been validated.

    Args:
        value: The input value.
    """
    _simd_construction_checks[type, size]()
    self.value = value
# Construct via a variadic type which has the same number of elements as
# the SIMD value.
@always_inline("nodebug")
fn __init__(inout self, *elems: Scalar[type]):
    """Builds a SIMD vector from a variadic list of scalar elements.

    A single argument is broadcast to every element of the vector. Otherwise
    the arguments are stored element by element, in order.

    Constraints:
        The number of input values is 1 or equal to size of the SIMD
        vector.

    Args:
        elems: The variadic list of elements from which the SIMD vector is
            constructed.
    """
    _simd_construction_checks[type, size]()

    var count: Int = len(elems)

    # Exactly one element: broadcast it across the whole vector.
    if count == 1:
        self.value = __mlir_op.`pop.simd.splat`[
            _type = __mlir_type[`!pop.simd<`, size.value, `, `, type.value, `>`]
        ](elems[0].value)
        return

    # TODO: Make this a compile-time check when possible.
    debug_assert(
        size == count,
        (
            "mismatch in the number of elements in the SIMD variadic"
            " constructor"
        ),
    )

    # Start from the default-initialized vector, then fill each element.
    self = Self()

    @parameter
    for lane in range(size):
        self[lane] = elems[lane]
@always_inline("nodebug")
fn __init__(inout self, value: FloatLiteral):
    """Initializes the SIMD vector from a float literal.

    The value is splatted across all the elements of the SIMD vector. The
    literal is materialized as `f16`, `bf16` or `f32` when `type` is the
    matching floating-point type, and as `f64` for every other type, before
    being cast to the element type.

    Args:
        value: The input value.
    """
    _simd_construction_checks[type, size]()

    # TODO (#36686): This introduces unneeded casts here to work around
    # parameter if issues.
    @parameter
    if type == DType.float16:
        var f16_literal = __mlir_op.`kgen.float_literal.convert`[
            _type = __mlir_type.f16
        ](value.value)
        var f16_scalar = __mlir_op.`pop.cast_from_builtin`[
            _type = __mlir_type[`!pop.scalar<f16>`]
        ](f16_literal)
        var f16_element = __mlir_op.`pop.cast`[
            _type = __mlir_type[`!pop.scalar<`, type.value, `>`]
        ](f16_scalar)
        self = Self(
            __mlir_op.`pop.simd.splat`[
                _type = __mlir_type[
                    `!pop.simd<`, size.value, `,`, type.value, `>`
                ]
            ](f16_element)
        )
    elif type == DType.bfloat16:
        var bf16_literal = __mlir_op.`kgen.float_literal.convert`[
            _type = __mlir_type.bf16
        ](value.value)
        var bf16_scalar = __mlir_op.`pop.cast_from_builtin`[
            _type = __mlir_type[`!pop.scalar<bf16>`]
        ](bf16_literal)
        var bf16_element = __mlir_op.`pop.cast`[
            _type = __mlir_type[`!pop.scalar<`, type.value, `>`]
        ](bf16_scalar)
        self = Self(
            __mlir_op.`pop.simd.splat`[
                _type = __mlir_type[
                    `!pop.simd<`, size.value, `,`, type.value, `>`
                ]
            ](bf16_element)
        )
    elif type == DType.float32:
        var f32_literal = __mlir_op.`kgen.float_literal.convert`[
            _type = __mlir_type.f32
        ](value.value)
        var f32_scalar = __mlir_op.`pop.cast_from_builtin`[
            _type = __mlir_type[`!pop.scalar<f32>`]
        ](f32_literal)
        var f32_element = __mlir_op.`pop.cast`[
            _type = __mlir_type[`!pop.scalar<`, type.value, `>`]
        ](f32_scalar)
        self = Self(
            __mlir_op.`pop.simd.splat`[
                _type = __mlir_type[
                    `!pop.simd<`, size.value, `,`, type.value, `>`
                ]
            ](f32_element)
        )
    else:
        var f64_literal = __mlir_op.`kgen.float_literal.convert`[
            _type = __mlir_type.f64
        ](value.value)
        var f64_scalar = __mlir_op.`pop.cast_from_builtin`[
            _type = __mlir_type[`!pop.scalar<f64>`]
        ](f64_literal)
        var f64_element = __mlir_op.`pop.cast`[
            _type = __mlir_type[`!pop.scalar<`, type.value, `>`]
        ](f64_scalar)
        self = Self(
            __mlir_op.`pop.simd.splat`[
                _type = __mlir_type[
                    `!pop.simd<`, size.value, `,`, type.value, `>`
                ]
            ](f64_element)
        )
@always_inline("nodebug")
fn __len__(self) -> Int:
    """Gets the length of the SIMD vector.

    Returns:
        The length of the SIMD vector, i.e. its `size` parameter.
    """
    return self.size
@always_inline("nodebug")
fn __bool__(self) -> Bool:
    """Converts the SIMD scalar into a boolean value.

    Constraints:
        The size of the SIMD vector must be 1.

    Returns:
        True if the SIMD scalar is non-zero and False otherwise.
    """
    constrained[
        size == 1,
        (
            "The truth value of a SIMD vector with more than one element is"
            " ambiguous. Use the builtin `any()` or `all()` functions"
            " instead."
        ),
    ]()

    # Cast to bool first, then view the result as a bool scalar.
    var truth = self.cast[DType.bool]()
    return rebind[Scalar[DType.bool]](truth).value
@staticmethod
@always_inline("nodebug")
fn splat(x: Scalar[type]) -> Self:
    """Splats (broadcasts) the element onto the vector.

    Args:
        x: The input scalar value.

    Returns:
        A new SIMD vector whose elements are the same as the input value.
    """
    _simd_construction_checks[type, size]()

    var broadcast = __mlir_op.`pop.simd.splat`[
        _type = __mlir_type[`!pop.simd<`, size.value, `, `, type.value, `>`]
    ](x.value)
    return Self {value: broadcast}
@always_inline("nodebug")
fn cast[target: DType](self) -> SIMD[target, size]:
    """Casts the elements of the SIMD vector to the target element type.

    Parameters:
        target: The target DType.

    Returns:
        A new SIMD vector whose elements have been cast to the target
        element type.
    """

    # Same element type: only the static type needs to change.
    @parameter
    if type == target:
        return rebind[SIMD[target, size]](self)

    @parameter
    if has_neon() and (type == DType.bfloat16 or target == DType.bfloat16):
        # BF16 is not supported on neon systems, so a zero vector is
        # returned instead of performing the conversion.
        return _unchecked_zero[target, size]()

    @parameter
    if type == DType.bool:
        # bool -> target: pick 1 where set and 0 elsewhere.
        return self.select(SIMD[target, size](1), SIMD[target, size](0))
    elif target == DType.bool:
        # anything -> bool: an element is true when it differs from zero.
        return rebind[SIMD[target, size]](self != 0)
    elif type == DType.bfloat16:
        # bfloat16 -> target goes through float32.
        var widened = _bfloat16_to_f32(
            rebind[SIMD[DType.bfloat16, size]](self)
        )
        return rebind[SIMD[target, size]](widened.cast[target]())
    elif target == DType.bfloat16:
        # source -> bfloat16 goes through float32 as well.
        var as_f32 = self.cast[DType.float32]()
        return rebind[SIMD[target, size]](_f32_to_bfloat16(as_f32))
    elif target == DType.address:
        # source -> address: cast to index, then index to pointer.
        var as_index = __mlir_op.`pop.cast`[
            _type = __mlir_type[`!pop.simd<`, size.value, `, index>`]
        ](self.value)
        var as_address = SIMD[DType.address, size](
            __mlir_op.`pop.index_to_pointer`[
                _type = __mlir_type[`!pop.simd<`, size.value, `, address >`]
            ](as_index)
        )
        return rebind[SIMD[target, size]](as_address)
    elif (type == DType.address) and target.is_integral():
        # address -> integral: pointer to index, then a regular cast.
        var raw_address = rebind[
            __mlir_type[`!pop.simd<`, size.value, `, address >`]
        ](self.value)
        var as_index_vec = SIMD[DType.index, size](
            __mlir_op.`pop.pointer_to_index`[
                _type = __mlir_type[
                    `!pop.simd<`, size.value, `, `, DType.index.value, `>`
                ]
            ](raw_address)
        )
        return as_index_vec.cast[target]()
    else:
        # Everything else maps directly onto the generic cast operation.
        return __mlir_op.`pop.cast`[
            _type = __mlir_type[`!pop.simd<`, size.value, `, `, target.value, `>`]
        ](self.value)
@always_inline("nodebug")
fn __int__(self) -> Int:
    """Casts the value to an Int. If there is a fractional component,
    then the fractional part is truncated.

    Constraints:
        The size of the SIMD vector must be 1.

    Returns:
        The value as an integer.
    """
    constrained[size == 1, "expected a scalar type"]()

    # View the one-element vector as a scalar before casting it to index.
    var scalar = rebind[Scalar[type]](self)
    return __mlir_op.`pop.cast`[_type = __mlir_type.`!pop.scalar<index>`](
        scalar.value
    )
@always_inline
fn __str__(self) -> String:
    """Get the SIMD as a string.

    The string is produced by `String.format_sequence(self)`.

    Returns:
        A string representation.
    """
    return String.format_sequence(self)
@always_inline
fn __repr__(self) -> String:
    """Get the representation of the SIMD value e.g. "SIMD[DType.int8, 2](1, 2)".

    The elements are formatted with scientific notation enabled, which only
    affects floating-point types.

    Returns:
        The representation of the SIMD value.
    """
    var buffer = String()
    var formatter = buffer._unsafe_to_formatter()
    self.format_to[use_scientific_notation=True](formatter)

    var body = buffer.as_string_slice()

    @parameter
    if size > 1:
        # Vectors with several elements are formatted inside `[` and `]`;
        # drop that first and last byte.
        # TODO: Fix when slice indexing is implemented on StringSlice
        body = StringSlice(unsafe_from_utf8=buffer.as_bytes_slice()[1:-1])

    return (
        "SIMD[" + type.__repr__() + ", " + str(size) + "](" + body + ")"
    )
@always_inline
fn format_to(self, inout writer: Formatter):
    """
    Formats this SIMD value to the provided formatter.

    This forwards to the parameterized overload with
    `use_scientific_notation=False`.

    Args:
        writer: The formatter to write to.
    """
    self.format_to[use_scientific_notation=False](writer)
# This overload is required to keep SIMD compliant with the Formattable
# trait; without it, the call to `String.format_sequence(self)` in
# SIMD.__str__ will fail to compile.
fn format_to[use_scientific_notation: Bool](self, inout writer: Formatter):
    """
    Formats this SIMD value to the provided formatter.

    Vectors with more than one element are wrapped in `[` and `]` with the
    elements separated by `, `; a single-element vector is written without
    brackets.

    Parameters:
        use_scientific_notation: Whether floats should use scientific notation.
            This parameter does not apply to integer types. It is not
            consulted on the `triple_is_nvidia_cuda()` path.

    Args:
        writer: The formatter to write to.
    """

    # Print an opening `[`.
    @parameter
    if size > 1:
        writer.write_str("[")

    # Print each element.
    for i in range(size):
        var element = self[i]
        # Print separators between each element.
        if i != 0:
            writer.write_str(", ")

        @parameter
        if triple_is_nvidia_cuda():
            # FIXME(MSTDL-406):
            #   Both `_printf` calls below print "out of band" with the
            #   `Formatter` passed in, meaning this will only work if
            #   `Formatter` is an unbuffered wrapper around printf (which
            #   Formatter.stdout currently is by default).
            #
            #   This is a workaround to permit debug formatting of
            #   floating-point values on GPU, where printing to stdout
            #   is the only way the Formatter framework is currently
            #   used.
            @parameter
            if type.is_floating_point():
                # get_dtype_printf_format hardcodes 17 digits of precision.
                _printf["%g"](element)
            else:
                _printf[_get_dtype_printf_format[type]()](element)
        else:

            @parameter
            if use_scientific_notation and type.is_floating_point():
                # Build the printf-style format string at compile time.
                alias float_format = "%." + _scientific_notation_digits[
                    type
                ]() + "e"
                _format_scalar[type, float_format](writer, element)
            else:
                _format_scalar(writer, element)

    # Print a closing `]`.
    @parameter
    if size > 1:
        writer.write_str("]")
@always_inline("nodebug")
fn __add__(self, rhs: Self) -> Self:
    """Computes `self + rhs`.

    Constraints:
        The element type of the SIMD vector must be numeric.

    Args:
        rhs: The rhs value.

    Returns:
        A new vector whose element at position `i` is computed as
        `self[i] + rhs[i]`.
    """
    constrained[type.is_numeric(), "the SIMD type must be numeric"]()
    return __mlir_op.`pop.add`(self.value, rhs.value)
@always_inline("nodebug")
fn __sub__(self, rhs: Self) -> Self:
    """Computes `self - rhs`.

    Constraints:
        The element type of the SIMD vector must be numeric.

    Args:
        rhs: The rhs value.

    Returns:
        A new vector whose element at position `i` is computed as
        `self[i] - rhs[i]`.
    """
    constrained[type.is_numeric(), "the SIMD type must be numeric"]()
    return __mlir_op.`pop.sub`(self.value, rhs.value)
@always_inline("nodebug")
fn __mul__(self, rhs: Self) -> Self:
    """Computes `self * rhs`.

    For `DType.bool` vectors the product is computed as the elementwise
    logical AND of the operands.

    Constraints:
        The element type of the SIMD vector must be numeric or bool.

    Args:
        rhs: The rhs value.

    Returns:
        A new vector whose element at position `i` is computed as
        `self[i] * rhs[i]`.
    """

    @parameter
    if type == DType.bool:
        var lhs_mask = rebind[Self._Mask](self)
        var rhs_mask = rebind[Self._Mask](rhs)
        return (lhs_mask & rhs_mask).cast[type]()

    constrained[type.is_numeric(), "the SIMD type must be numeric"]()
    return __mlir_op.`pop.mul`(self.value, rhs.value)
@always_inline("nodebug")
fn __truediv__(self, rhs: Self) -> Self:
    """Computes `self / rhs`.

    Constraints:
        The element type of the SIMD vector must be numeric.

    Args:
        rhs: The rhs value.

    Returns:
        A new vector whose element at position `i` is computed as
        `self[i] / rhs[i]`.
    """
    constrained[type.is_numeric(), "the SIMD type must be numeric"]()
    return __mlir_op.`pop.div`(self.value, rhs.value)
@always_inline("nodebug")
fn __floordiv__(self, rhs: Self) -> Self:
    """Returns the division of self and rhs rounded down to the nearest
    integer.

    If `any(rhs)` is false, zero is returned instead of dividing.

    Constraints:
        The element type of the SIMD vector must be numeric.

    Args:
        rhs: The value to divide with.

    Returns:
        `floor(self / rhs)` value.
    """
    constrained[type.is_numeric(), "the type must be numeric"]()

    if not any(rhs):
        # this should raise an exception.
        return 0

    var quotient = self / rhs

    @parameter
    if type.is_floating_point():
        return quotient.__floor__()
    elif type.is_unsigned():
        # The plain quotient is returned as-is for unsigned types.
        return quotient
    else:
        # Fast path: with all operands positive no correction is applied.
        if all((self > 0) & (rhs > 0)):
            return quotient

        # Subtract one where the operands' signs differ and the remainder
        # is non-zero.
        var remainder = self - quotient * rhs
        var round_down = ((rhs < 0) ^ (self < 0)) & (remainder != 0)
        return quotient - round_down.cast[type]()
@always_inline("nodebug")
fn __rfloordiv__(self, rhs: Self) -> Self:
    """Returns the division of rhs and self rounded down to the nearest
    integer.

    This is the reflected form of `__floordiv__` and evaluates `rhs // self`.

    Constraints:
        The element type of the SIMD vector must be numeric.

    Args:
        rhs: The value to divide by self.

    Returns:
        `floor(rhs / self)` value.
    """
    constrained[type.is_numeric(), "the type must be numeric"]()
    return rhs // self
@always_inline("nodebug")
fn __mod__(self, rhs: Self) -> Self:
    """Returns the remainder of self divided by rhs.

    If `any(rhs)` is false, zero is returned instead of dividing.

    Constraints:
        The element type of the SIMD vector must be numeric.

    Args:
        rhs: The value to divide on.

    Returns:
        The remainder of dividing self by rhs.
    """
    constrained[type.is_numeric(), "the type must be numeric"]()

    if not any(rhs):
        # this should raise an exception.
        return 0

    @parameter
    if type.is_unsigned():
        return __mlir_op.`pop.rem`(self.value, rhs.value)
    else:
        var quotient = self / rhs

        @parameter
        if type.is_floating_point():
            # Drop the fractional part of the floating-point quotient.
            quotient = llvm_intrinsic[
                "llvm.trunc", Self, has_side_effect=False
            ](quotient)

        # Add `rhs` to the remainder where the operands' signs differ and
        # the remainder is non-zero.
        var remainder = self - quotient * rhs
        var needs_shift = ((rhs < 0) ^ (self < 0)) & (remainder != 0)
        return remainder + needs_shift.select(rhs, Self(0))
@always_inline("nodebug")
fn __rmod__(self, value: Self) -> Self:
    """Returns `value mod self`.

    This is the reflected form of `__mod__` and evaluates `value % self`.

    Constraints:
        The element type of the SIMD vector must be numeric.

    Args:
        value: The other value.

    Returns:
        `value mod self`.
    """
    constrained[type.is_numeric(), "the type must be numeric"]()
    return value % self
@always_inline("nodebug")
fn __pow__(self, exp: Int) -> Self:
    """Computes the vector raised to the power of the input integer value.

    Constraints:
        The element type of the SIMD vector must be numeric.

    Args:
        exp: The exponent value.

    Returns:
        A SIMD vector where each element is raised to the power of the
        specified exponent value.
    """
    constrained[type.is_numeric(), "the SIMD type must be numeric"]()
    # The exponent is handed to `_pow` with `DType.index` as its dtype
    # parameter.
    return _pow[type, size, DType.index](self, exp)
# TODO(#22771): remove this overload.
@always_inline("nodebug")
fn __pow__(self, exp: Self) -> Self:
    """Computes the vector raised elementwise to the right hand side power.

    Constraints:
        The element type of the SIMD vector must be numeric.

    Args:
        exp: The exponent value.

    Returns:
        A SIMD vector where each element is raised to the power of the
        specified exponent value.
    """
    constrained[type.is_numeric(), "the SIMD type must be numeric"]()
    return _pow(self, exp)
@always_inline("nodebug")
fn __lt__(self, rhs: Self) -> Self._Mask:
    """Compares two SIMD vectors using less-than comparison.

    Args:
        rhs: The rhs of the operation.

    Returns:
        A new bool SIMD vector of the same size whose element at position
        `i` is True or False depending on the expression
        `self[i] < rhs[i]`.
    """
    # Elementwise compare with the `lt` predicate.
    return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred lt>`](
        self.value, rhs.value
    )
@always_inline("nodebug")
fn __le__(self, rhs: Self) -> Self._Mask:
    """Compares two SIMD vectors using less-than-or-equal comparison.

    Args:
        rhs: The rhs of the operation.

    Returns:
        A new bool SIMD vector of the same size whose element at position
        `i` is True or False depending on the expression
        `self[i] <= rhs[i]`.
    """
    # Elementwise compare with the `le` predicate.
    return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred le>`](
        self.value, rhs.value
    )
@always_inline("nodebug")
fn __eq__(self, rhs: Self) -> Self._Mask:
    """Compares two SIMD vectors using equal-to comparison.

    Args:
        rhs: The rhs of the operation.

    Returns:
        A new bool SIMD vector of the same size whose element at position
        `i` is True or False depending on the expression
        `self[i] == rhs[i]`.
    """

    # Because of #30525, we roll our own implementation for eq: on neon,
    # bfloat16 operands are bitcast to the matching integral type and those
    # integers are compared instead.
    @parameter
    if has_neon() and type == DType.bfloat16:
        alias bits_type = _integral_type_of[type]()
        var lhs_bits = bitcast[bits_type, size](self)
        var rhs_bits = bitcast[bits_type, size](rhs)
        return lhs_bits == rhs_bits

    return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred eq>`](
        self.value, rhs.value
    )
@always_inline("nodebug")
fn __ne__(self, rhs: Self) -> Self._Mask:
    """Compares two SIMD vectors using not-equal comparison.

    Args:
        rhs: The rhs of the operation.

    Returns:
        A new bool SIMD vector of the same size whose element at position
        `i` is True or False depending on the expression
        `self[i] != rhs[i]`.
    """

    # Because of #30525, we roll our own implementation for ne: on neon,
    # bfloat16 operands are bitcast to the matching integral type and those
    # integers are compared instead.
    @parameter
    if has_neon() and type == DType.bfloat16:
        alias bits_type = _integral_type_of[type]()
        var lhs_bits = bitcast[bits_type, size](self)
        var rhs_bits = bitcast[bits_type, size](rhs)
        return lhs_bits != rhs_bits

    return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred ne>`](
        self.value, rhs.value
    )
@always_inline("nodebug")
fn __gt__(self, rhs: Self) -> Self._Mask:
    """Compares two SIMD vectors using greater-than comparison.

    Args:
        rhs: The rhs of the operation.

    Returns:
        A new bool SIMD vector of the same size whose element at position
        `i` is True or False depending on the expression
        `self[i] > rhs[i]`.
    """
    # Elementwise compare with the `gt` predicate.
    return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred gt>`](
        self.value, rhs.value
    )
@always_inline("nodebug")
fn __ge__(self, rhs: Self) -> Self._Mask:
    """Compares two SIMD vectors using greater-than-or-equal comparison.

    Args:
        rhs: The rhs of the operation.

    Returns:
        A new bool SIMD vector of the same size whose element at position
        `i` is True or False depending on the expression
        `self[i] >= rhs[i]`.
    """
    # Elementwise compare with the `ge` predicate.
    return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop<cmp_pred ge>`](
        self.value, rhs.value
    )
# ===------------------------------------------------------------------=== #
# Unary operations.
# ===------------------------------------------------------------------=== #
@always_inline("nodebug")
fn __pos__(self) -> Self:
"""Defines the unary `+` operation.