For xla tensors, use an alternative way to get a unique id (#25802)
* For xla tensors, use an alternative way to get a unique id

Because xla tensors don't have storage.

* add is_torch_tpu_available check
qihqi authored Aug 31, 2023
1 parent 716bb2e commit f8468b4
Showing 1 changed file with 13 additions and 2 deletions.
src/transformers/pytorch_utils.py
@@ -19,7 +19,7 @@
 from safetensors.torch import storage_ptr, storage_size
 from torch import nn

-from .utils import logging
+from .utils import is_torch_tpu_available, logging


 ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
@@ -285,4 +285,15 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
     guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
     non-overlapping lifetimes may have the same id.
     """
-    return tensor.device, storage_ptr(tensor), storage_size(tensor)
+    if tensor.device.type == "xla" and is_torch_tpu_available():
+        # NOTE: xla tensors dont have storage
+        # use some other unique id to distinguish.
+        # this is a XLA tensor, it must be created using torch_xla's
+        # device. So the following import is safe:
+        import torch_xla
+
+        unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
+    else:
+        unique_id = storage_ptr(tensor)
+
+    return tensor.device, unique_id, storage_size(tensor)
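The dispatch the diff introduces can be sketched in isolation. Since XLA tensors expose no storage, the commit substitutes the backend-assigned tensor id (via `torch_xla._XLAC._xla_get_tensor_id`) for the storage pointer from `safetensors.torch.storage_ptr`. The sketch below mirrors only that branching logic; `FakeDevice` and `FakeTensor` are hypothetical torch-free stand-ins used here so the example is self-contained, not part of the actual transformers code:

```python
from typing import Tuple


class FakeDevice:
    """Hypothetical stand-in for torch.device (only .type is needed here)."""

    def __init__(self, type_: str):
        self.type = type_


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor carrying the three values the
    real helpers would return: the XLA tensor id, the storage pointer, and
    the storage size."""

    def __init__(self, device_type: str, tensor_id: int, ptr: int, size: int):
        self.device = FakeDevice(device_type)
        self.tensor_id = tensor_id  # stands in for torch_xla._XLAC._xla_get_tensor_id
        self.ptr = ptr              # stands in for safetensors.torch.storage_ptr
        self.size = size            # stands in for safetensors.torch.storage_size


def id_tensor_storage(tensor: FakeTensor) -> Tuple[FakeDevice, int, int]:
    # Mirror of the commit's dispatch: XLA tensors have no storage, so a
    # backend-assigned tensor id replaces the storage pointer in the key.
    if tensor.device.type == "xla":
        unique_id = tensor.tensor_id
    else:
        unique_id = tensor.ptr
    return tensor.device, unique_id, tensor.size
```

For example, a CPU tensor is keyed by its storage pointer while an XLA tensor with the same sizes is keyed by its tensor id, so the two never collide on the id component alone; the returned tuple still includes the device, which keeps keys distinct across backends.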
