diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 45a7b2796c9..2bbf148536d 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -1377,6 +1377,32 @@ def test_data_loader_with_non_batch_size_and_mini_batch(self):
     ):
       data, _ = iter(train_device_loader).__next__()
 
+  def test_fallback(self):
+    device = torch_xla.device()
+
+    theta: float = 10000
+    dim = 16
+    end = 2048
+
+    torch_xla.sync()
+    freqs = 1.0 / (
+        theta
+        **(torch.arange(0, dim, 2, device=device)[:(dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)
+    freqs = torch.outer(t, freqs).float()
+    freqs_cis = torch.polar(torch.ones_like(freqs, device=device),
+                            freqs)  # complex64
+    # torch.polar will fall back on CPU, the result tensor should not have any sharding spec
+    self.assertIn("ShardingSpec: None",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
+    # it will be a CPU tensor, the sharding spec is not specified so it won't be moved to device yet
+    self.assertIn("Tensor on host: with size [2048, 8]",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
+    torch_xla.sync()
+    # data should be on device and replicated now
+    self.assertIn("Data Shape: c64[2048,8]\n OpSharding: {replicated}",
+                  torch_xla._XLAC._get_xla_tensor_debug_info(freqs_cis))
+
 if __name__ == '__main__':
   test = unittest.main()
 
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 77684e9bd34..095a6ce4163 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -562,6 +562,7 @@ void XLATensor::UpdateFromTensor(at::Tensor tensor, bool sync) {
   at::Tensor coyped_tensor = torch::lazy::CopyTensor(tensor, dtype());
   SetTensorData(coyped_tensor);
   data()->handle = nullptr;
+  data()->sharding = nullptr;
   AssignIrValue(torch::lazy::Value());
   if (data()->view != nullptr) {
     torch::lazy::Value ir_value = GetIrValueForTensor(coyped_tensor, device);