-
Notifications
You must be signed in to change notification settings - Fork 493
/
Copy pathtensor_methods.cpp
3081 lines (2775 loc) · 132 KB
/
tensor_methods.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include "torch_xla/csrc/tensor_methods.h"
#include <ATen/core/Reduction.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/util.h>
#include <algorithm>
#include <functional>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "torch_xla/csrc/LazyIr.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/data_ops.h"
#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/adam_optimizer_step.h"
#include "torch_xla/csrc/ops/adaptive_max_pool2d.h"
#include "torch_xla/csrc/ops/all_gather.h"
#include "torch_xla/csrc/ops/all_reduce.h"
#include "torch_xla/csrc/ops/all_to_all.h"
#include "torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h"
#include "torch_xla/csrc/ops/amp_update_scale.h"
#include "torch_xla/csrc/ops/arithmetic_ir_ops.h"
#include "torch_xla/csrc/ops/as_strided.h"
#include "torch_xla/csrc/ops/avg_pool_nd.h"
#include "torch_xla/csrc/ops/avg_pool_nd_backward.h"
#include "torch_xla/csrc/ops/bernoulli.h"
#include "torch_xla/csrc/ops/cast.h"
#include "torch_xla/csrc/ops/cat.h"
#include "torch_xla/csrc/ops/cdist.h"
#include "torch_xla/csrc/ops/collective_permute.h"
#include "torch_xla/csrc/ops/constant.h"
#include "torch_xla/csrc/ops/constant_pad_nd.h"
#include "torch_xla/csrc/ops/convolution_backward_overrideable.h"
#include "torch_xla/csrc/ops/convolution_overrideable.h"
#include "torch_xla/csrc/ops/count_nonzero.h"
#include "torch_xla/csrc/ops/cumprod.h"
#include "torch_xla/csrc/ops/cumsum.h"
#include "torch_xla/csrc/ops/custom_sharding.h"
#include "torch_xla/csrc/ops/dequant_tensor.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/diagonal.h"
#include "torch_xla/csrc/ops/discrete_uniform.h"
#include "torch_xla/csrc/ops/einsum.h"
#include "torch_xla/csrc/ops/einsum_backward.h"
#include "torch_xla/csrc/ops/expand.h"
#include "torch_xla/csrc/ops/expand_symint.h"
#include "torch_xla/csrc/ops/exponential.h"
#include "torch_xla/csrc/ops/flip.h"
#include "torch_xla/csrc/ops/gather.h"
#include "torch_xla/csrc/ops/generic.h"
#include "torch_xla/csrc/ops/generic_slice.h"
#include "torch_xla/csrc/ops/get_dimensions_size.h"
#include "torch_xla/csrc/ops/hardtanh_backward.h"
#include "torch_xla/csrc/ops/index_ops.h"
#include "torch_xla/csrc/ops/index_select.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/kth_value.h"
#include "torch_xla/csrc/ops/linear_interpolation.h"
#include "torch_xla/csrc/ops/linspace.h"
#include "torch_xla/csrc/ops/log_softmax.h"
#include "torch_xla/csrc/ops/logsumexp.h"
#include "torch_xla/csrc/ops/mark_tensor.h"
#include "torch_xla/csrc/ops/masked_scatter.h"
#include "torch_xla/csrc/ops/masked_select.h"
#include "torch_xla/csrc/ops/max_in_dim.h"
#include "torch_xla/csrc/ops/max_pool_nd.h"
#include "torch_xla/csrc/ops/max_pool_nd_backward.h"
#include "torch_xla/csrc/ops/max_unpool_nd.h"
#include "torch_xla/csrc/ops/mean.h"
#include "torch_xla/csrc/ops/min_in_dim.h"
#include "torch_xla/csrc/ops/mse_loss.h"
#include "torch_xla/csrc/ops/mse_loss_backward.h"
#include "torch_xla/csrc/ops/multinomial.h"
#include "torch_xla/csrc/ops/native_batch_norm_backward.h"
#include "torch_xla/csrc/ops/native_batch_norm_forward.h"
#include "torch_xla/csrc/ops/native_dropout.h"
#include "torch_xla/csrc/ops/nll_loss.h"
#include "torch_xla/csrc/ops/nll_loss2d.h"
#include "torch_xla/csrc/ops/nll_loss2d_backward.h"
#include "torch_xla/csrc/ops/nll_loss_backward.h"
#include "torch_xla/csrc/ops/nms.h"
#include "torch_xla/csrc/ops/nonzero.h"
#include "torch_xla/csrc/ops/normal.h"
#include "torch_xla/csrc/ops/not_supported.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/ops/optimization_barrier.h"
#include "torch_xla/csrc/ops/permute.h"
#include "torch_xla/csrc/ops/prod.h"
#include "torch_xla/csrc/ops/put.h"
#include "torch_xla/csrc/ops/qr.h"
#include "torch_xla/csrc/ops/quant_tensor.h"
#include "torch_xla/csrc/ops/randperm.h"
#include "torch_xla/csrc/ops/recv.h"
#include "torch_xla/csrc/ops/reduce_scatter.h"
#include "torch_xla/csrc/ops/reflection_pad2d.h"
#include "torch_xla/csrc/ops/reflection_pad2d_backward.h"
#include "torch_xla/csrc/ops/replication_pad.h"
#include "torch_xla/csrc/ops/replication_pad_backward.h"
#include "torch_xla/csrc/ops/resize.h"
#include "torch_xla/csrc/ops/roll.h"
#include "torch_xla/csrc/ops/rrelu_with_noise.h"
#include "torch_xla/csrc/ops/rrelu_with_noise_backward.h"
#include "torch_xla/csrc/ops/scalar.h"
#include "torch_xla/csrc/ops/scatter.h"
#include "torch_xla/csrc/ops/scatter_add.h"
#include "torch_xla/csrc/ops/scatter_reduce.h"
#include "torch_xla/csrc/ops/select.h"
#include "torch_xla/csrc/ops/send.h"
#include "torch_xla/csrc/ops/sgd_optimizer_step.h"
#include "torch_xla/csrc/ops/softmax.h"
#include "torch_xla/csrc/ops/split.h"
#include "torch_xla/csrc/ops/squeeze.h"
#include "torch_xla/csrc/ops/stack.h"
#include "torch_xla/csrc/ops/std.h"
#include "torch_xla/csrc/ops/std_mean.h"
#include "torch_xla/csrc/ops/sum.h"
#include "torch_xla/csrc/ops/svd.h"
#include "torch_xla/csrc/ops/threshold.h"
#include "torch_xla/csrc/ops/threshold_backward.h"
#include "torch_xla/csrc/ops/topk.h"
#include "torch_xla/csrc/ops/tpu_custom_call.h"
#include "torch_xla/csrc/ops/triangular_solve.h"
#include "torch_xla/csrc/ops/uniform.h"
#include "torch_xla/csrc/ops/unsqueeze.h"
#include "torch_xla/csrc/ops/upsample_bilinear2d.h"
#include "torch_xla/csrc/ops/upsample_bilinear2d_backward.h"
#include "torch_xla/csrc/ops/upsample_nearest2d.h"
#include "torch_xla/csrc/ops/upsample_nearest2d_backward.h"
#include "torch_xla/csrc/ops/user_computation.h"
#include "torch_xla/csrc/ops/var.h"
#include "torch_xla/csrc/ops/var_mean.h"
#include "torch_xla/csrc/ops/view.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "torch_xla/csrc/shape_builder.h"
#include "torch_xla/csrc/tensor.h"
#include "torch_xla/csrc/tensor_ops.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_graph_executor.h"
#include "xla/literal_util.h"
namespace torch_xla {
namespace tensor_methods {
namespace {
// Pair of IR scalar values holding a lower (`min`) and upper (`max`) bound;
// produced by GetMinMaxValues() in this file.
struct MinMaxValues {
// IR value for the lower bound.
torch::lazy::Value min;
// IR value for the upper bound.
torch::lazy::Value max;
};
// Broadcasts `input` to the dimensions of `target_shape` through an Expand
// node. When the dimensions already match, `input` is returned untouched and
// no node is created.
torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
                               const xla::Shape& target_shape) {
  const bool already_matches =
      GetXlaShape(input).dimensions() == target_shape.dimensions();
  if (already_matches) {
    return input;
  }
  std::vector<int64_t> expanded_sizes =
      torch::lazy::ToVector<int64_t>(target_shape.dimensions());
  return torch::lazy::MakeNode<Expand>(input, std::move(expanded_sizes));
}
// Builds IR scalars for clamp-style bounds of `tensor`.
//
// At least one of `min` / `max` must be provided. A missing bound falls back
// to the corresponding limit returned by XlaHelpers::MinMaxValues for the XLA
// type of the tensor's dtype. Both scalars are created with the element type
// of the tensor's XLA shape, on the tensor's device.
MinMaxValues GetMinMaxValues(const XLATensorPtr& tensor,
                             const c10::optional<at::Scalar>& min,
                             const c10::optional<at::Scalar>& max) {
  XLA_CHECK(min || max)
      << "At least one of \'min\' or \'max\' must not be None";
  XlaHelpers::MinMax limits =
      XlaHelpers::MinMaxValues(XlaTypeFromTorchType(tensor->dtype()));
  auto shape = tensor->shape();
  const xla::PrimitiveType scalar_type = shape.get().element_type();
  const torch::lazy::BackendDevice& device = tensor->GetDevice();

  torch::lazy::Value lower = XLAGraphExecutor::Get()->GetIrValueForScalar(
      min ? *min : limits.min, scalar_type, device);
  torch::lazy::Value upper = XLAGraphExecutor::Get()->GetIrValueForScalar(
      max ? *max : limits.max, scalar_type, device);
  return {std::move(lower), std::move(upper)};
}
// Fails with a descriptive error unless `t` has exactly `expected_rank`
// dimensions. The message names the argument (1-based position `arg_number`
// and `arg_name`) and the calling op `tag`.
void CheckRank(const XLATensorPtr& t, int64_t expected_rank,
               const std::string& tag, const std::string& arg_name,
               int arg_number) {
  const int64_t rank = t->shape().get().rank();
  XLA_CHECK_EQ(rank, expected_rank)
      << "Expected " << expected_rank << "-dimensional tensor, but got "
      << rank << "-dimensional tensor for "
      << "argument #" << arg_number << " '" << arg_name << "'"
      << " (while checking arguments for " << tag << ")";
}
template <typename T>
void CheckShapeDimensions(const T& size) {
XLA_CHECK(std::all_of(size.begin(), size.end(), [](int64_t dim) {
return dim >= 0;
})) << "Dimensions cannot be negative numbers";
}
// Fails with a descriptive error unless dimension `dim` of `t` has size
// `expected_size`. The message names the argument (1-based position
// `arg_number` and `arg_name`) and the calling op `tag`.
void CheckDimensionSize(const XLATensorPtr& t, int64_t dim,
                        int64_t expected_size, const std::string& tag,
                        const std::string& arg_name, int arg_number) {
  // Query the size once and reuse it for both the check and the message.
  const int64_t dim_size = t->size(dim);
  XLA_CHECK_EQ(dim_size, expected_size)
      << "Expected tensor to have size " << expected_size << " at dimension "
      << dim << ", but got size " << dim_size << " for "
      << "argument #" << arg_number << " '" << arg_name << "'"
      << " (while checking arguments for " << tag << ")";
}
// Validates batched-matmul operands; consistent with the checks in
// bmm_out_or_baddbmm_. Both operands must be rank 3. batch2 must match
// batch1's dim 0 (batch size) in its dim 0, and batch1's dim 2 (contraction
// size) in its dim 1.
void CheckBmmDimension(const std::string& tag, const XLATensorPtr& batch1,
                       const XLATensorPtr& batch2) {
  CheckRank(batch1, 3, tag, "batch1", 1);
  CheckRank(batch2, 3, tag, "batch2", 2);
  const int64_t batch_size = batch1->size(0);
  const int64_t contraction_size = batch1->size(2);
  CheckDimensionSize(batch2, 0, batch_size, tag, "batch2", 2);
  CheckDimensionSize(batch2, 1, contraction_size, tag, "batch2", 2);
}
// Resolves -1 placeholders in the trailing `shape.rank()` entries of
// `dimensions` to the corresponding sizes of `shape`. Leading entries are
// returned unchanged. Requires `dimensions` to have at least `shape.rank()`
// entries.
std::vector<int64_t> GetExpandDimensions(const xla::Shape& shape,
                                         std::vector<int64_t> dimensions) {
  XLA_CHECK_GE(dimensions.size(), shape.rank()) << shape;
  const int64_t rank = shape.rank();
  const int64_t offset = static_cast<int64_t>(dimensions.size()) - rank;
  for (int64_t i = 0; i < rank; ++i) {
    int64_t& requested = dimensions[offset + i];
    if (requested == -1) {
      requested = shape.dimensions(i);
    }
  }
  return dimensions;
}
// Normalizes an int-list attribute to exactly `length` entries.
//
// An empty `list` is first replaced by `def`. A single entry is broadcast to
// `length` entries when `length > 1`. Any other list must already have
// exactly `length` entries, otherwise the check fails and names the
// attribute `name`.
std::vector<int64_t> CheckIntList(absl::Span<const int64_t> list, size_t length,
                                  const std::string& name,
                                  std::vector<int64_t> def = {}) {
  std::vector<int64_t> values =
      list.empty() ? std::move(def) : torch::lazy::ToVector<int64_t>(list);
  if (values.size() != 1 || length <= 1) {
    XLA_CHECK_EQ(values.size(), length)
        << "Invalid length for the '" << name << "' attribute";
    return values;
  }
  // Copy the fill value before resizing so it never aliases vector storage.
  const int64_t fill = values.front();
  values.resize(length, fill);
  return values;
}
// Returns the shape used for batch-norm weight/bias: a shape built by
// ShapeBuilder from dimension 1 of `input`'s shape, with the XLA element type
// derived from `input`'s dtype and device.
xla::Shape BatchNormFeaturesShape(const XLATensorPtr& input) {
  const xla::PrimitiveType feature_type =
      MakeXlaPrimitiveType(input->dtype(), &input->GetDevice());
  auto input_shape = input->shape();
  ShapeBuilder builder(feature_type);
  builder.Add(input_shape.get(), 1);
  return builder.Build();
}
// Returns `input`'s IR value when `input` is defined. Otherwise returns an IR
// scalar holding `default_value` with shape `default_shape` on `device`.
torch::lazy::Value GetIrValueOrDefault(
    const XLATensorPtr& input, const at::Scalar& default_value,
    const xla::Shape& default_shape, const torch::lazy::BackendDevice& device) {
  if (input) {
    return input->GetIrValue();
  }
  return XLAGraphExecutor::Get()->GetIrValueForScalar(default_value,
                                                      default_shape, device);
}
// Returns `input`'s IR value. When its XLA element type satisfies
// xla::primitive_util::IsIntegralType, a Cast to `float_type` is applied; any
// other element type is returned unchanged.
// NOTE(review): despite the name, only integral types are cast. Presumably
// PRED/complex values pass through as-is; confirm this is intended.
torch::lazy::Value GetFloatingIrValue(const XLATensorPtr& input,
                                      at::ScalarType float_type) {
  torch::lazy::Value ir_value = input->GetIrValue();
  const bool is_integral = xla::primitive_util::IsIntegralType(
      GetXlaShape(ir_value).element_type());
  if (!is_integral) {
    return ir_value;
  }
  return torch::lazy::MakeNode<Cast>(ir_value, float_type);
}
// Returns `input_value` as a PRED (boolean) IR value, inserting a Cast node
// only when its element type is not already PRED.
torch::lazy::Value GetBooleanIrValue(torch::lazy::Value input_value) {
  if (GetXlaShape(input_value).element_type() == xla::PrimitiveType::PRED) {
    return input_value;
  }
  return torch::lazy::MakeNode<Cast>(input_value, xla::PrimitiveType::PRED);
}
// Returns `tensor`'s IR value, or an empty optional when `tensor` is undefined.
absl::optional<torch::lazy::Value> GetOptionalIrValue(
    const XLATensorPtr& tensor) {
  if (!tensor) {
    return {};
  }
  return tensor->GetIrValue();
}
// Builds an as_strided ViewInfo over `input_shape`.
//
// The view's shape comes from XlaHelpers::GetDynamicReshape(input_shape,
// size). `stride` is stored in the AsStridedInfo, and `storage_offset` is
// stored only when present (otherwise AsStridedInfo's default offset is kept).
ViewInfo CreateAsStridedViewInfo(const xla::Shape& input_shape,
                                 std::vector<int64_t> size,
                                 std::vector<int64_t> stride,
                                 c10::optional<int64_t> storage_offset) {
  xla::Shape view_shape = XlaHelpers::GetDynamicReshape(input_shape, size);
  AsStridedInfo strided;
  strided.stride = std::move(stride);
  if (storage_offset.has_value()) {
    strided.offset = *storage_offset;
  }
  return ViewInfo(ViewInfo::Type::kAsStrided, std::move(view_shape),
                  input_shape, std::move(strided));
}
// Emits comparison `kind` between `input` and the scalar `other`. The scalar
// is materialized on `input`'s device. The result tensor lives on the same
// device and has logical type at::ScalarType::Bool.
XLATensorPtr DispatchComparisonOp(c10::Symbol kind, const XLATensorPtr& input,
                                  const at::Scalar& other) {
  const torch::lazy::BackendDevice& device = input->GetDevice();
  torch::lazy::Value lhs = input->GetIrValue();
  torch::lazy::Value rhs =
      XLAGraphExecutor::Get()->GetIrValueForScalar(other, device);
  torch::lazy::NodePtr comparison = ComparisonOp(kind, lhs, rhs);
  return XLATensor::Create(comparison, device, at::ScalarType::Bool);
}
// Emits comparison `kind` between tensors `input` and `other`. The result
// lives on `input`'s device and has logical type at::ScalarType::Bool.
XLATensorPtr DispatchComparisonOp(c10::Symbol kind, const XLATensorPtr& input,
                                  const XLATensorPtr& other) {
  torch::lazy::NodePtr comparison =
      ComparisonOp(kind, input->GetIrValue(), other->GetIrValue());
  return XLATensor::Create(comparison, input->GetDevice(),
                           at::ScalarType::Bool);
}
} // namespace
//////////////////////////////////////////////////////////////////////////////
// XLA-dedicated operators follow here (grouped by kind, not strictly alphabetical).
//////////////////////////////////////////////////////////////////////////////
// All-reduces a single tensor and returns the result as a new tensor.
//
// The AllReduce node consumes the device's current all-reduce token. Its
// output 1 is installed as that device's new token, so later collectives are
// chained after this one.
XLATensorPtr all_reduce(const XLATensorPtr& input, AllReduceType reduce_type,
                        double scale, std::vector<std::vector<int64_t>> groups,
                        bool pin_layout) {
  const torch::lazy::BackendDevice& device = input->GetDevice();
  std::vector<torch::lazy::Value> operands{input->GetIrValue()};
  torch::lazy::NodePtr reduce = torch::lazy::MakeNode<AllReduce>(
      reduce_type, operands, GetAllReduceToken(device), scale,
      std::move(groups), pin_layout);
  SetAllReduceToken(device, std::make_shared<torch::lazy::Value>(reduce, 1));
  torch::lazy::Value reduced(reduce, 0);
  return input->CreateFrom(reduced);
}
void all_reduce(const std::vector<XLATensorPtr>& inputs,
AllReduceType reduce_type, double scale,
std::vector<std::vector<int64_t>> groups, bool pin_layout) {
std::vector<torch::lazy::Value> input_values;
input_values.reserve(inputs.size());
for (auto& input : inputs) {
input_values.push_back(input->GetIrValue());
}
torch::lazy::NodePtr node = torch::lazy::MakeNode<AllReduce>(
reduce_type, input_values, GetAllReduceToken(inputs.front()->GetDevice()),
scale, std::move(groups), pin_layout);
for (size_t i = 0; i < inputs.size(); ++i) {
inputs[i]->SetInPlaceIrValue(torch::lazy::Value(node, i));
}
SetAllReduceToken(inputs.front()->GetDevice(),
std::make_shared<torch::lazy::Value>(node, inputs.size()));
}
// Reduce-scatters `input` along `scatter_dim` across `shard_count` shards.
// Returns the scattered tensor (node output 0) and the output token (node
// output 1) for chaining further collectives.
std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
    const XLATensorPtr& input, const torch::lazy::Value& token,
    AllReduceType reduce_type, double scale, int64_t scatter_dim,
    int64_t shard_count, std::vector<std::vector<int64_t>> groups,
    bool pin_layout) {
  torch::lazy::NodePtr scatter = torch::lazy::MakeNode<ReduceScatter>(
      reduce_type, input->GetIrValue(), token, scale, scatter_dim, shard_count,
      std::move(groups), pin_layout);
  torch::lazy::Value data(scatter, 0);
  torch::lazy::Value out_token(scatter, 1);
  return {input->CreateFrom(data), out_token};
}
// Reduce-scatters `input` and writes the result into `output` via SetIrValue.
// Returns the output token (node output 1).
torch::lazy::Value reduce_scatter_out(XLATensorPtr& output,
                                      const XLATensorPtr& input,
                                      const torch::lazy::Value& token,
                                      AllReduceType reduce_type, double scale,
                                      int64_t scatter_dim, int64_t shard_count,
                                      std::vector<std::vector<int64_t>> groups,
                                      bool pin_layout) {
  torch::lazy::NodePtr scatter = torch::lazy::MakeNode<ReduceScatter>(
      reduce_type, input->GetIrValue(), token, scale, scatter_dim, shard_count,
      std::move(groups), pin_layout);
  torch::lazy::Value data(scatter, 0);
  output->SetIrValue(data);
  return torch::lazy::Value(scatter, 1);
}
// Reduce-scatters several tensors through one ReduceScatterCoalesced node.
//
// Returns one new tensor per input (node output i) and the output token (node
// output inputs.size()).
std::pair<std::vector<XLATensorPtr>, torch::lazy::Value>
reduce_scatter_coalesced(const std::vector<XLATensorPtr>& inputs,
                         const torch::lazy::Value& token,
                         AllReduceType reduce_type, double scale,
                         int64_t scatter_dim, int64_t shard_count,
                         std::vector<std::vector<int64_t>> groups,
                         bool pin_layout) {
  std::vector<torch::lazy::Value> input_values;
  input_values.reserve(inputs.size());
  for (const auto& input : inputs) {
    input_values.push_back(input->GetIrValue());
  }
  torch::lazy::NodePtr node = torch::lazy::MakeNode<ReduceScatterCoalesced>(
      reduce_type, input_values, token, scale, scatter_dim, shard_count,
      std::move(groups), pin_layout);
  std::vector<XLATensorPtr> result;
  result.reserve(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
  }
  // Move the tensor list into the returned pair instead of copying it.
  return {std::move(result), torch::lazy::Value(node, inputs.size())};
}
// Reduce-scatters several tensors and writes node output i into outputs[i].
// Returns the output token (node output inputs.size()). `outputs` must have
// exactly as many entries as `inputs`.
torch::lazy::Value reduce_scatter_coalesced_out(
    const std::vector<XLATensorPtr>& outputs,
    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
    AllReduceType reduce_type, double scale, int64_t scatter_dim,
    int64_t shard_count, std::vector<std::vector<int64_t>> groups,
    bool pin_layout) {
  // outputs is indexed by input position below; guard against overrun.
  XLA_CHECK_EQ(outputs.size(), inputs.size())
      << "reduce_scatter_coalesced_out: number of outputs must match inputs";
  std::vector<torch::lazy::Value> input_values;
  input_values.reserve(inputs.size());
  for (const auto& input : inputs) {
    input_values.push_back(input->GetIrValue());
  }
  torch::lazy::NodePtr node = torch::lazy::MakeNode<ReduceScatterCoalesced>(
      reduce_type, input_values, token, scale, scatter_dim, shard_count,
      std::move(groups), pin_layout);
  for (size_t i = 0; i < inputs.size(); ++i) {
    outputs[i]->SetIrValue(torch::lazy::Value(node, i));
  }
  return torch::lazy::Value(node, inputs.size());
}
// All-to-all exchange of `input`: splits along `split_dimension` into
// `split_count` pieces and concatenates received pieces along
// `concat_dimension`. Returns the result tensor (node output 0) and the output
// token (node output 1).
std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
    const XLATensorPtr& input, const torch::lazy::Value& token,
    int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
    std::vector<std::vector<int64_t>> groups, bool pin_layout) {
  torch::lazy::NodePtr exchange = torch::lazy::MakeNode<AllToAll>(
      input->GetIrValue(), token, split_dimension, concat_dimension,
      split_count, std::move(groups), pin_layout);
  torch::lazy::Value data(exchange, 0);
  torch::lazy::Value out_token(exchange, 1);
  return {input->CreateFrom(data), out_token};
}
// All-gathers `input` along `dim` across `shard_count` shards.
//
// The AllGather node consumes the device's current all-reduce token. Its
// output 1 becomes that device's new token. Returns the gathered tensor (node
// output 0).
XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim,
                        int64_t shard_count,
                        std::vector<std::vector<int64_t>> groups,
                        bool pin_layout) {
  const torch::lazy::BackendDevice& device = input->GetDevice();
  torch::lazy::NodePtr gather = torch::lazy::MakeNode<AllGather>(
      input->GetIrValue(), GetAllReduceToken(device), dim, shard_count,
      std::move(groups), pin_layout);
  SetAllReduceToken(device, std::make_shared<torch::lazy::Value>(gather, 1));
  torch::lazy::Value gathered(gather, 0);
  return input->CreateFrom(gathered);
}
// All-gathers `input` along `dim` and writes the result into `output` via
// SetIrValue. Returns the output token (node output 1).
torch::lazy::Value all_gather_out(XLATensorPtr& output,
                                  const XLATensorPtr& input,
                                  const torch::lazy::Value& token, int64_t dim,
                                  int64_t shard_count,
                                  std::vector<std::vector<int64_t>> groups,
                                  bool pin_layout) {
  torch::lazy::NodePtr gather = torch::lazy::MakeNode<AllGather>(
      input->GetIrValue(), token, dim, shard_count, std::move(groups),
      pin_layout);
  torch::lazy::Value gathered(gather, 0);
  output->SetIrValue(gathered);
  return torch::lazy::Value(gather, 1);
}
// All-gathers several tensors through one AllGatherCoalesced node.
//
// Returns one new tensor per input (node output i) and the output token (node
// output inputs.size()).
std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather_coalesced(
    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
    int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
    bool pin_layout) {
  std::vector<torch::lazy::Value> input_values;
  input_values.reserve(inputs.size());
  for (const auto& input : inputs) {
    input_values.push_back(input->GetIrValue());
  }
  torch::lazy::NodePtr node = torch::lazy::MakeNode<AllGatherCoalesced>(
      input_values, token, dim, shard_count, std::move(groups), pin_layout);
  std::vector<XLATensorPtr> result;
  result.reserve(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
  }
  // Move the tensor list into the returned pair instead of copying it.
  return {std::move(result), torch::lazy::Value(node, inputs.size())};
}
// All-gathers several tensors and writes node output i into outputs[i].
// Returns the output token (node output inputs.size()). `outputs` must have
// exactly as many entries as `inputs`.
torch::lazy::Value all_gather_coalesced_out(
    std::vector<XLATensorPtr>& outputs, const std::vector<XLATensorPtr>& inputs,
    const torch::lazy::Value& token, int64_t dim, int64_t shard_count,
    std::vector<std::vector<int64_t>> groups, bool pin_layout) {
  // outputs is indexed by input position below; guard against overrun.
  XLA_CHECK_EQ(outputs.size(), inputs.size())
      << "all_gather_coalesced_out: number of outputs must match inputs";
  std::vector<torch::lazy::Value> input_values;
  input_values.reserve(inputs.size());
  for (const auto& input : inputs) {
    input_values.push_back(input->GetIrValue());
  }
  torch::lazy::NodePtr node = torch::lazy::MakeNode<AllGatherCoalesced>(
      input_values, token, dim, shard_count, std::move(groups), pin_layout);
  for (size_t i = 0; i < inputs.size(); ++i) {
    outputs[i]->SetIrValue(torch::lazy::Value(node, i));
  }
  return torch::lazy::Value(node, inputs.size());
}
// Sends `input` between replicas according to `source_target_pairs`. Returns
// the permuted tensor (node output 0) and the output token (node output 1).
std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
    const XLATensorPtr& input, const torch::lazy::Value& token,
    std::vector<std::pair<int64_t, int64_t>> source_target_pairs) {
  torch::lazy::NodePtr permute = torch::lazy::MakeNode<CollectivePermute>(
      input->GetIrValue(), token, std::move(source_target_pairs));
  torch::lazy::Value data(permute, 0);
  torch::lazy::Value out_token(permute, 1);
  return {input->CreateFrom(data), out_token};
}
void custom_sharding_(
const XLATensorPtr& input,
const std::shared_ptr<XLATensor::ShardingSpec>& sharding_spec) {
input->SetInPlaceIrValue(
torch::lazy::MakeNode<CustomSharding>(input->GetIrValue()));
input->SetShardingSpec(*sharding_spec);
}
// Replaces `output`'s IR (in place) with a TpuCustomCall node. The node takes
// the IR values of `inputs` in order, uses `output`'s current shape, and
// carries the opaque `payload` string.
void tpu_custom_call_(XLATensorPtr& output,
                      const std::vector<XLATensorPtr>& inputs,
                      const std::string& payload) {
  std::vector<torch::lazy::Value> operands;
  for (const XLATensorPtr& tensor : inputs) {
    operands.push_back(tensor->GetIrValue());
  }
  torch::lazy::NodePtr call = torch::lazy::MakeNode<TpuCustomCall>(
      operands, output->shape().get(), payload);
  output->SetInPlaceIrValue(call);
}
// Returns an Int-typed tensor produced by a GetDimensionsSize node over
// `input` for the requested `dimensions`.
XLATensorPtr get_dimensions_size(const XLATensorPtr& input,
                                 std::vector<int64_t> dimensions) {
  torch::lazy::NodePtr sizes_node = torch::lazy::MakeNode<GetDimensionsSize>(
      input->GetIrValue(), std::move(dimensions));
  return input->CreateFrom(sizes_node, at::ScalarType::Int);
}
// Receives data on `channel_id` into `output`.
//
// A Recv node shaped like `output`'s current IR is created. Its output 0
// replaces `output`'s IR via SetIrValue. Returns a new tensor over that same
// value together with the output token (node output 1).
std::pair<XLATensorPtr, torch::lazy::Value> recv(
    XLATensorPtr& output, const torch::lazy::Value& token, int64_t channel_id) {
  torch::lazy::NodePtr recv_node = torch::lazy::MakeNode<ir::ops::Recv>(
      token, GetXlaShape(output->GetIrValue()), channel_id);
  torch::lazy::Value received(recv_node, 0);
  torch::lazy::Value out_token(recv_node, 1);
  output->SetIrValue(received);
  return {output->CreateFrom(received), out_token};
}
// Sends `input` on `channel_id`. Returns a tensor built from the Send node's
// output 0 and the output token (node output 1).
std::pair<XLATensorPtr, torch::lazy::Value> send(
    const XLATensorPtr& input, const torch::lazy::Value& token,
    int64_t channel_id) {
  torch::lazy::NodePtr send_node = torch::lazy::MakeNode<ir::ops::Send>(
      input->GetIrValue(), token, channel_id);
  torch::lazy::Value data(send_node, 0);
  torch::lazy::Value out_token(send_node, 1);
  return {input->CreateFrom(data), out_token};
}
void sgd_optimizer_step_(const XLATensorPtr& found_inf, XLATensorPtr& step,
XLATensorPtr& param, XLATensorPtr& buf,
const XLATensorPtr& d_p, double weight_decay,
double momentum, double lr, double dampening,
bool nesterov, bool maximize) {
torch::lazy::Value weight_decay_value =
XLAGraphExecutor::Get()->GetIrValueForScalar(weight_decay, param->shape(),
param->GetDevice());
torch::lazy::Value momentum_value =
XLAGraphExecutor::Get()->GetIrValueForScalar(momentum, param->shape(),
param->GetDevice());
torch::lazy::Value lr_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
maximize ? -lr : lr, param->shape(), param->GetDevice());
torch::lazy::Value dampening_value =
XLAGraphExecutor::Get()->GetIrValueForScalar(dampening, param->shape(),
param->GetDevice());
torch::lazy::NodePtr node = torch::lazy::MakeNode<SgdOptimizerStep>(
found_inf->GetIrValue(), step->GetIrValue(), param->GetIrValue(),
buf->GetIrValue(), d_p->GetIrValue(), weight_decay_value, momentum_value,
lr_value, dampening_value,
/*use_weight_decay=*/weight_decay != 0,
/*use_momentum=*/momentum != 0, /*use_nesterov=*/nesterov);
step->SetInPlaceIrValue(torch::lazy::Value(node, 0));
param->SetInPlaceIrValue(torch::lazy::Value(node, 1));
buf->SetInPlaceIrValue(torch::lazy::Value(node, 2));
}
void adam_optimizer_step_(const XLATensorPtr& found_inf, XLATensorPtr& step,
XLATensorPtr& param, const XLATensorPtr& grad,
XLATensorPtr& exp_avg, XLATensorPtr& exp_avg_sq,
XLATensorPtr& max_exp_avg_sq, double beta1,
double beta2, double lr, double weight_decay,
double eps, bool amsgrad, bool maximize,
bool use_adamw) {
torch::lazy::Value grad_value =
maximize ? mul(grad, -1)->GetIrValue() : grad->GetIrValue();
torch::lazy::Value beta1_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
beta1, found_inf->shape(), found_inf->GetDevice());
torch::lazy::Value beta2_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
beta2, found_inf->shape(), found_inf->GetDevice());
torch::lazy::Value lr_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
lr, found_inf->shape(), found_inf->GetDevice());
torch::lazy::Value weight_decay_value =
XLAGraphExecutor::Get()->GetIrValueForScalar(weight_decay, param->shape(),
param->GetDevice());
torch::lazy::Value eps_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
eps, param->shape(), param->GetDevice());
torch::lazy::NodePtr node = torch::lazy::MakeNode<AdamOptimizerStep>(
found_inf->GetIrValue(), step->GetIrValue(), param->GetIrValue(),
grad_value, exp_avg->GetIrValue(), exp_avg_sq->GetIrValue(),
max_exp_avg_sq->GetIrValue(), beta1_value, beta2_value, lr_value,
weight_decay_value, eps_value,
/*use_weight_decay=*/weight_decay != 0,
/*use_amsgrad=*/amsgrad, /*use_adamw=*/use_adamw);
step->SetInPlaceIrValue(torch::lazy::Value(node, 0));
param->SetInPlaceIrValue(torch::lazy::Value(node, 1));
exp_avg->SetInPlaceIrValue(torch::lazy::Value(node, 2));
exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 3));
if (amsgrad) {
max_exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 4));
}
}
std::vector<XLATensorPtr> user_computation(
const std::string& opname, absl::Span<const XLATensorPtr> inputs,
runtime::ComputationClient::ComputationPtr computation) {
XLA_CHECK(!inputs.empty());
std::vector<torch::lazy::Value> input_values;
for (auto& input : inputs) {
input_values.push_back(input->GetIrValue());
}
torch::lazy::NodePtr node = torch::lazy::MakeNode<UserComputation>(
torch::lazy::OpKind::Get(opname), input_values, std::move(computation));
// Cast can be one of the user computation and we don't want to inherit the
// logical_element_type in this case
return inputs.front()->MakeOutputTensors(node,
/*inherit_logical_type=*/false);
}
//////////////////////////////////////////////////////////////////////////////
// ATen operators follow here, listed in alphabetical order.
//////////////////////////////////////////////////////////////////////////////
// In-place left shift of `input` by the scalar `other`.
void __ilshift__(XLATensorPtr& input, const at::Scalar& other) {
  auto shifted = Lshift(input->GetIrValue(), other);
  input->SetInPlaceIrValue(std::move(shifted));
}
// In-place left shift of `input` by the tensor `other`.
void __ilshift__(XLATensorPtr& input, const XLATensorPtr& other) {
  auto shifted = Lshift(input->GetIrValue(), other->GetIrValue());
  input->SetInPlaceIrValue(std::move(shifted));
}
// In-place right shift of `input` by the scalar `other`.
void __irshift__(XLATensorPtr& input, const at::Scalar& other) {
  auto shifted = Rshift(input->GetIrValue(), other);
  input->SetInPlaceIrValue(std::move(shifted));
}
// In-place right shift of `input` by the tensor `other`.
void __irshift__(XLATensorPtr& input, const XLATensorPtr& other) {
  auto shifted = Rshift(input->GetIrValue(), other->GetIrValue());
  input->SetInPlaceIrValue(std::move(shifted));
}
// Returns `input` left-shifted by the scalar `other`, optionally tagged with
// `logical_element_type`.
XLATensorPtr __lshift__(const XLATensorPtr& input, const at::Scalar& other,
                        c10::optional<at::ScalarType> logical_element_type) {
  auto shifted = Lshift(input->GetIrValue(), other);
  return input->CreateFrom(std::move(shifted), logical_element_type);
}
// Returns `input` left-shifted by the tensor `other`, optionally tagged with
// `logical_element_type`.
XLATensorPtr __lshift__(const XLATensorPtr& input, const XLATensorPtr& other,
                        c10::optional<at::ScalarType> logical_element_type) {
  auto shifted = Lshift(input->GetIrValue(), other->GetIrValue());
  return input->CreateFrom(std::move(shifted), logical_element_type);
}
// Returns `input` right-shifted by the scalar `other`, optionally tagged with
// `logical_element_type`.
XLATensorPtr __rshift__(const XLATensorPtr& input, const at::Scalar& other,
                        c10::optional<at::ScalarType> logical_element_type) {
  auto shifted = Rshift(input->GetIrValue(), other);
  return input->CreateFrom(std::move(shifted), logical_element_type);
}
// Returns `input` right-shifted by the tensor `other`, optionally tagged with
// `logical_element_type`.
XLATensorPtr __rshift__(const XLATensorPtr& input, const XLATensorPtr& other,
                        c10::optional<at::ScalarType> logical_element_type) {
  auto shifted = Rshift(input->GetIrValue(), other->GetIrValue());
  return input->CreateFrom(std::move(shifted), logical_element_type);
}
// Adaptive 2-D max pooling of `input` to `output_size`. Returns the pooled
// values (node output 0) and the indices (node output 1), the latter with
// logical type at::ScalarType::Long.
std::tuple<XLATensorPtr, XLATensorPtr> adaptive_max_pool2d(
    const XLATensorPtr& input, std::vector<int64_t> output_size) {
  // `output_size` is owned by value here; hand it over instead of copying it,
  // as _adaptive_avg_pool2d already does.
  torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveMaxPool2d>(
      input->GetIrValue(), std::move(output_size));
  XLATensorPtr out = input->CreateFrom(torch::lazy::Value(node, 0));
  XLATensorPtr indices =
      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long);
  return std::make_tuple(std::move(out), std::move(indices));
}
// Gradient of adaptive 2-D max pooling with respect to `input`, given
// `grad_output`.
XLATensorPtr adaptive_max_pool2d_backward(const XLATensorPtr& grad_output,
                                          const XLATensorPtr& input) {
  auto grad_input =
      AdaptiveMaxPool2dBackward(grad_output->GetIrValue(), input->GetIrValue());
  return input->CreateFrom(std::move(grad_input));
}
// Adaptive 2-D average pooling of `input` to `output_size`.
XLATensorPtr _adaptive_avg_pool2d(const XLATensorPtr& input,
                                  std::vector<int64_t> output_size) {
  torch::lazy::NodePtr pooled = torch::lazy::MakeNode<AdaptiveAvgPool2d>(
      input->GetIrValue(), std::move(output_size));
  return input->CreateFrom(pooled);
}
// Gradient of adaptive 2-D average pooling with respect to `input`, given
// `grad_output`.
XLATensorPtr _adaptive_avg_pool2d_backward(const XLATensorPtr& grad_output,
                                           const XLATensorPtr& input) {
  torch::lazy::NodePtr grad_input =
      torch::lazy::MakeNode<AdaptiveAvgPool2dBackward>(
          grad_output->GetIrValue(), input->GetIrValue());
  return input->CreateFrom(grad_input);
}
void _amp_foreach_non_finite_check_and_unscale_(std::vector<XLATensorPtr> self,
XLATensorPtr& found_inf,
const XLATensorPtr& inv_scale) {
std::vector<torch::lazy::Value> inputs;
XLATensorPtr new_inv_scale = max(inv_scale);
for (const auto& x : self) {
inputs.push_back(x->GetIrValue());
}
torch::lazy::NodePtr node =
torch::lazy::MakeNode<AmpForachNonFiniteCheckAndUnscale>(
inputs, found_inf->GetIrValue(), new_inv_scale->GetIrValue());
for (size_t i = 0; i < self.size(); ++i) {
self[i]->SetInPlaceIrValue(torch::lazy::Value(node, i));
}
found_inf->SetInPlaceIrValue(torch::lazy::Value(node, self.size()));
}
// Builds an AmpUpdateScale node and writes its outputs back in place:
// output 1 goes to `growth_tracker` and output 0 goes to `current_scale`.
// Note the node's operand order: growth_tracker precedes current_scale.
void _amp_update_scale_(XLATensorPtr& current_scale,
                        XLATensorPtr& growth_tracker,
                        const XLATensorPtr& found_inf,
                        double scale_growth_factor, double scale_backoff_factor,
                        int growth_interval) {
  torch::lazy::NodePtr update = torch::lazy::MakeNode<AmpUpdateScale>(
      growth_tracker->GetIrValue(), current_scale->GetIrValue(),
      found_inf->GetIrValue(), scale_growth_factor, scale_backoff_factor,
      growth_interval);
  const torch::lazy::Value updated_scale(update, 0);
  const torch::lazy::Value updated_tracker(update, 1);
  growth_tracker->SetInPlaceIrValue(updated_tracker);
  current_scale->SetInPlaceIrValue(updated_scale);
}
// Wraps an Abs IR node over `input`'s IR value in a tensor derived from
// `input`.
XLATensorPtr abs(const XLATensorPtr& input) {
  torch::lazy::NodePtr abs_node =
      torch::lazy::MakeNode<Abs>(input->GetIrValue());
  return input->CreateFrom(abs_node);
}
// Returns input + other * alpha as a tensor derived from `input`.
// `alpha` is materialized as a scalar of `other`'s element type on `input`'s
// device. When either operand's shape is dynamic, the SymIntElements of
// `other` are also passed to GetIrValueForScalar.
XLATensorPtr add(const XLATensorPtr& input, const XLATensorPtr& other,
                 const at::Scalar& alpha,
                 c10::optional<at::ScalarType> logical_element_type) {
  xla::Shape input_shape = input->shape().get();
  xla::Shape other_shape = other->shape().get();
  const torch::lazy::BackendDevice& device = input->GetDevice();
  xla::Shape alpha_shape = xla::ShapeUtil::MakeScalarShape(
      MakeXlaPrimitiveType(other->dtype(), &device));
  const bool has_dynamic_operand =
      input_shape.is_dynamic() || other_shape.is_dynamic();
  torch::lazy::Value alpha_value;
  if (has_dynamic_operand) {
    SymIntElements sym_int_elements(other->GetIrValue());
    alpha_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
        alpha, alpha_shape, sym_int_elements, logical_element_type, device);
  } else {
    alpha_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
        alpha, alpha_shape, logical_element_type, device);
  }
  return input->CreateFrom(
      input->GetIrValue() + other->GetIrValue() * alpha_value,
      logical_element_type);
}
// Returns input + other * alpha, where both scalars are materialized as
// scalars of `input`'s element type on `input`'s device.
XLATensorPtr add(const XLATensorPtr& input, const at::Scalar& other,
                 const at::Scalar& alpha,
                 c10::optional<at::ScalarType> logical_element_type) {
  const torch::lazy::BackendDevice& device = input->GetDevice();
  const xla::Shape scalar_shape = xla::ShapeUtil::MakeScalarShape(
      MakeXlaPrimitiveType(input->dtype(), &device));
  auto executor = XLAGraphExecutor::Get();
  torch::lazy::Value other_constant = executor->GetIrValueForScalar(
      other, scalar_shape, logical_element_type, device);
  torch::lazy::Value alpha_constant = executor->GetIrValueForScalar(
      alpha, scalar_shape, logical_element_type, device);
  return input->CreateFrom(
      input->GetIrValue() + other_constant * alpha_constant,
      logical_element_type);
}
// Returns a tensor derived from `input` whose IR is AddMatMulOp over the IR
// values of `input`, `weight` and `bias`, in that operand order.
XLATensorPtr addmm(const XLATensorPtr& input, const XLATensorPtr& weight,
                   const XLATensorPtr& bias) {
  torch::lazy::Value input_value = input->GetIrValue();
  torch::lazy::Value weight_value = weight->GetIrValue();
  torch::lazy::Value bias_value = bias->GetIrValue();
  return input->CreateFrom(AddMatMulOp(input_value, weight_value, bias_value));
}
// Overwrites `out` with an ARange IR value built from start/end/step, then
// sets `out`'s scalar type to `scalar_type`.
void arange_out(XLATensorPtr& out, const at::Scalar& start,
                const at::Scalar& end, const at::Scalar& step,
                at::ScalarType scalar_type) {
  auto range = ARange(start, end, step, scalar_type);
  out->SetIrValue(range);
  out->SetScalarType(scalar_type);
}
// Returns an as-strided tensor over `input`. The default path builds an
// AsStrided IR node (a missing storage_offset is treated as 0). When
// XLA_DISABLE_FUNCTIONALIZATION is set, a view tensor is created instead.
XLATensorPtr as_strided(const XLATensorPtr& input, std::vector<int64_t> size,
                        std::vector<int64_t> stride,
                        c10::optional<int64_t> storage_offset) {
  const bool functionalization_disabled =
      runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false);
  if (!functionalization_disabled) {
    return input->CreateFrom(torch::lazy::MakeNode<AsStrided>(
        input->GetIrValue(), std::move(size), std::move(stride),
        storage_offset.value_or(0)));
  }
  // See Note: [Disabling functionalization]
  auto input_shape = input->shape();
  return input->CreateViewTensor(CreateAsStridedViewInfo(
      input_shape, std::move(size), std::move(stride), storage_offset));
}
// In-place as_strided. If `input` already carries a view, an as-strided
// sub-view is installed with SetSubView. Otherwise its IR value is replaced by
// an AsStrided node (a missing storage_offset is treated as 0).
void as_strided_(XLATensorPtr& input, std::vector<int64_t> size,
                 std::vector<int64_t> stride,
                 c10::optional<int64_t> storage_offset) {
  if (input->data()->view != nullptr) {
    auto input_shape = input->shape();
    input->SetSubView(CreateAsStridedViewInfo(
        input_shape, std::move(size), std::move(stride), storage_offset));
    return;
  }
  input->SetIrValue(torch::lazy::MakeNode<AsStrided>(
      input->GetIrValue(), std::move(size), std::move(stride),
      storage_offset.value_or(0)));
}
// Builds an AvgPoolNd node over `input`. kernel_size, stride and padding are
// first normalized with CheckIntList to spatial_dim_count entries. The
// normalized kernel is passed as CheckIntList's extra argument for stride
// (presumably the fallback when stride is empty -- confirm in CheckIntList).
XLATensorPtr avg_pool_nd(const XLATensorPtr& input, int64_t spatial_dim_count,
                         std::vector<int64_t> kernel_size,
                         std::vector<int64_t> stride,
                         std::vector<int64_t> padding, bool ceil_mode,
                         bool count_include_pad,
                         std::optional<int> divisor_override) {
  std::vector<int64_t> checked_kernel =
      CheckIntList(kernel_size, spatial_dim_count, "kernel_size");
  std::vector<int64_t> checked_stride =
      CheckIntList(stride, spatial_dim_count, "stride", checked_kernel);
  std::vector<int64_t> checked_padding =
      CheckIntList(padding, spatial_dim_count, "padding");
  torch::lazy::NodePtr pooled = torch::lazy::MakeNode<AvgPoolNd>(
      input->GetIrValue(), spatial_dim_count, std::move(checked_kernel),
      std::move(checked_stride), std::move(checked_padding), ceil_mode,
      count_include_pad, divisor_override);
  return input->CreateFrom(pooled);
}
// Builds an AvgPoolNdBackward node from `out_backprop` and `input` after
// normalizing kernel_size, stride and padding with CheckIntList (as in
// avg_pool_nd). The result is derived from `out_backprop`.
XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
                                  const XLATensorPtr& input,
                                  int64_t spatial_dim_count,
                                  std::vector<int64_t> kernel_size,
                                  std::vector<int64_t> stride,
                                  std::vector<int64_t> padding, bool ceil_mode,
                                  bool count_include_pad) {
  std::vector<int64_t> checked_kernel =
      CheckIntList(kernel_size, spatial_dim_count, "kernel_size");
  std::vector<int64_t> checked_stride =
      CheckIntList(stride, spatial_dim_count, "stride", checked_kernel);
  std::vector<int64_t> checked_padding =
      CheckIntList(padding, spatial_dim_count, "padding");
  torch::lazy::NodePtr grad_input = torch::lazy::MakeNode<AvgPoolNdBackward>(
      out_backprop->GetIrValue(), input->GetIrValue(), spatial_dim_count,
      std::move(checked_kernel), std::move(checked_stride),
      std::move(checked_padding), ceil_mode, count_include_pad);
  return out_backprop->CreateFrom(grad_input);
}
// Builds a Baddbmm node after validating batch1/batch2 with CheckBmmDimension.
// `alpha` becomes the product multiplier (batch1's element type and device)
// and `beta` becomes the bias multiplier (input's element type and device).
XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1,
                     const XLATensorPtr& batch2, const at::Scalar& beta,
                     const at::Scalar& alpha) {
  CheckBmmDimension(/*tag=*/"baddbmm", batch1, batch2);
  auto executor = XLAGraphExecutor::Get();
  torch::lazy::Value product_scale = executor->GetIrValueForScalar(
      alpha, batch1->shape().get().element_type(), batch1->GetDevice());
  torch::lazy::Value bias_scale = executor->GetIrValueForScalar(
      beta, input->shape().get().element_type(), input->GetDevice());
  return input->CreateFrom(torch::lazy::MakeNode<Baddbmm>(
      input->GetIrValue(), batch1->GetIrValue(), batch2->GetIrValue(),
      bias_scale, product_scale));
}
// Builds a Bernoulli node of `input`'s shape. Its probability operand is
// `probability`, materialized via GetIrValueForScalar with `input`'s shape,
// and its seed comes from the executor's RNG seed for `input`'s device.
XLATensorPtr bernoulli(const XLATensorPtr& input, double probability) {
  auto input_shape = input->shape();
  auto executor = XLAGraphExecutor::Get();
  torch::lazy::Value probability_ir = executor->GetIrValueForScalar(
      probability, input_shape, input->GetDevice());
  auto rng_seed = executor->GetRngSeed(input->GetDevice());
  return input->CreateFrom(torch::lazy::MakeNode<Bernoulli>(
      probability_ir, rng_seed, input_shape.get()));
}
// Builds a Bernoulli node that uses `input`'s own IR value as the probability
// operand, seeded from the RNG seed for `input`'s device, with `input`'s shape.
XLATensorPtr bernoulli(const XLATensorPtr& input) {
  auto rng_seed = XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice());
  return input->CreateFrom(torch::lazy::MakeNode<Bernoulli>(
      input->GetIrValue(), rng_seed, input->shape().get()));
}
// In-place variant: replaces `input`'s IR value with a Bernoulli node of
// `input`'s shape whose probability operand is `probability`'s IR value.
void bernoulli_(XLATensorPtr& input, const XLATensorPtr& probability) {
  auto rng_seed = XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice());
  input->SetInPlaceIrValue(torch::lazy::MakeNode<Bernoulli>(
      probability->GetIrValue(), rng_seed, input->shape().get()));
}
// Builds a BitwiseAndTensor node from the IR values of `input` and `other`.
XLATensorPtr bitwise_and(const XLATensorPtr& input, const XLATensorPtr& other) {
  torch::lazy::Value lhs = input->GetIrValue();
  torch::lazy::Value rhs = other->GetIrValue();
  return input->CreateFrom(torch::lazy::MakeNode<BitwiseAndTensor>(lhs, rhs));
}
// Builds a BitwiseOrTensor node from the IR values of `input` and `other`.
XLATensorPtr bitwise_or(const XLATensorPtr& input, const XLATensorPtr& other) {
  torch::lazy::Value lhs = input->GetIrValue();
  torch::lazy::Value rhs = other->GetIrValue();
  return input->CreateFrom(torch::lazy::MakeNode<BitwiseOrTensor>(lhs, rhs));
}
// Builds a BitwiseXorTensor node from the IR values of `input` and `other`.
XLATensorPtr bitwise_xor(const XLATensorPtr& input, const XLATensorPtr& other) {
  torch::lazy::Value lhs = input->GetIrValue();
  torch::lazy::Value rhs = other->GetIrValue();
  return input->CreateFrom(torch::lazy::MakeNode<BitwiseXorTensor>(lhs, rhs));
}
// Validates the operands with CheckBmmDimension (tag "bmm"), then defers to
// matmul().
XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2) {
  CheckBmmDimension(/*tag=*/"bmm", batch1, batch2);
  XLATensorPtr product = matmul(batch1, batch2);
  return product;
}
// Lowers all `tensors` into a single BroadcastTensors node and returns the
// tensors produced by MakeOutputTensors on the first input (presumably one
// per node output -- confirm in MakeOutputTensors).
// Fails via XLA_CHECK when `tensors` is empty.
std::vector<XLATensorPtr> broadcast_tensors(
    absl::Span<const XLATensorPtr> tensors) {
  XLA_CHECK(!tensors.empty()) << "broadcast_tensors cannot take an empty list";
  std::vector<torch::lazy::Value> tensor_ir_values;
  // One IR value per input; reserve to avoid regrowth.
  tensor_ir_values.reserve(tensors.size());
  for (const auto& tensor : tensors) {
    tensor_ir_values.push_back(tensor->GetIrValue());
  }
  torch::lazy::NodePtr node = BroadcastTensors(tensor_ir_values);
  return tensors.front()->MakeOutputTensors(node);
}
// Builds a Cat node along `dim` over the non-skipped `tensors` and returns it
// as a tensor of type `dtype`, created from tensors[0].
//
// Shape checks for cat:
// - Tensors of rank 1 with size 0 (the legacy empty shape `[0]`) are skipped
//   and contribute nothing, e.g. ([2, 3, 5], []) passes.
// - Each remaining tensor must be compatible with the previous one once
//   dimension `dim` is removed (element types are ignored). Sizes may differ
//   only along `dim`, and a size of 0 there is allowed.
//   e.g. ([4, 0, 32, 32], [4, 2, 32, 32], dim=1) passes.
//        ([4, 0, 32, 32], [4, 2, 31, 32], dim=1) throws.
// - If every tensor is skipped, tensors[0] itself is returned: no Cat node is
//   built and `dtype` is not applied.
XLATensorPtr cat(absl::Span<const XLATensorPtr> tensors, int64_t dim,
                 at::ScalarType dtype) {
  XLA_CHECK_GT(tensors.size(), 0);
  std::vector<torch::lazy::Value> values;
  std::vector<xla::Shape> shapes;
  // At most one entry per input; reserve to avoid regrowth.
  values.reserve(tensors.size());
  shapes.reserve(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    xla::Shape tensor_shape = tensors[i]->shape();
    if (tensor_shape.rank() == 1 && tensor_shape.dimensions()[0] == 0) {
      continue;
    }
    // Canonicalize a possibly negative `dim` against this tensor's rank.
    dim = torch::lazy::GetCanonicalDimensionIndex(dim, tensor_shape.rank());
    // Drop the concatenation dimension so only the other dims are compared.
    tensor_shape.DeleteDimension(dim);
    if (!shapes.empty()) {
      XLA_CHECK(xla::ShapeUtil::CompatibleIgnoringElementType(shapes.back(),
                                                              tensor_shape))
          << shapes.back() << " vs. " << tensor_shape;
    }
    shapes.push_back(std::move(tensor_shape));
    values.push_back(tensors[i]->GetIrValue());
  }
  if (values.empty()) {
    return tensors[0];
  }
  return tensors[0]->CreateFrom(torch::lazy::MakeNode<Cat>(values, dim, dtype),
                                dtype);
}
// Builds a CdistForward node over x1 and x2 with exponent `p`, returned as a
// tensor derived from x1. The node's use_hamming flag is set when p == 0, and
// its use_chebyshev flag is set when p is infinite.
XLATensorPtr cdist_forward(const XLATensorPtr& x1, const XLATensorPtr& x2,
                           double p) {
  torch::lazy::Value exponent =
      XLAGraphExecutor::Get()->GetIrValueForScalar(p, x1->GetDevice());
  const bool use_hamming = (p == 0.0);
  const bool use_chebyshev = std::isinf(p);
  return x1->CreateFrom(torch::lazy::MakeNode<CdistForward>(
      x1->GetIrValue(), x2->GetIrValue(), exponent, use_hamming,
      use_chebyshev));
}
// Builds the Pdist_forward IR for `input` with exponent `p`, passing along
// `input`'s optional dtype.
XLATensorPtr pdist_forward(const XLATensorPtr& input, double p) {
  c10::optional<at::ScalarType> element_type = input->dtype_optional();
  auto pdist_ir = Pdist_forward(input->GetIrValue(), p, element_type);
  return input->CreateFrom(pdist_ir);
}
// Returns a tensor derived from `input` whose IR is Celu(input, alpha).
XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha) {
  auto celu_ir = Celu(input->GetIrValue(), alpha);
  return input->CreateFrom(celu_ir);
}
// In-place variant: replaces `input`'s IR value with Celu(input, alpha).
void celu_(XLATensorPtr& input, const at::Scalar& alpha) {
  torch::lazy::Value current = input->GetIrValue();
  input->SetInPlaceIrValue(Celu(current, alpha));
}
// Resolves the optional bounds with GetMinMaxValues, then returns a tensor
// derived from `input` whose IR is Clamp(input, min, max).
XLATensorPtr clamp(const XLATensorPtr& input,
                   const c10::optional<at::Scalar>& min,
                   const c10::optional<at::Scalar>& max) {
  MinMaxValues bounds = GetMinMaxValues(input, min, max);
  auto clamped = Clamp(input->GetIrValue(), bounds.min, bounds.max);
  return input->CreateFrom(clamped);
}