Skip to content

Commit

Permalink
Smarter check for is_tensor (huggingface#25871)
Browse files Browse the repository at this point in the history
* Smarter check for

* Use protected functions

* Do others too

* Apply suggestions from code review

Co-authored-by: Arthur <[email protected]>

* Address review comments

---------

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
2 people authored and parambharat committed Sep 26, 2023
1 parent 0a298e5 commit 87808fb
Showing 1 changed file with 80 additions and 29 deletions.
109 changes: 80 additions & 29 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,31 +71,64 @@ def strtobool(val):
raise ValueError(f"invalid truth value {val!r}")


def infer_framework_from_repr(x):
"""
Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
frameworks in a smart order, without the need to import the frameworks).
"""
representation = repr(x)
if representation.startswith("tensor"):
return "pt"
elif "tf.Tensor" in representation:
return "tf"
elif representation.startswith("Array"):
return "jax"
elif representation.startswith("array"):
return "np"


def _get_frameworks_and_test_func(x):
"""
Returns an (ordered since we are in Python 3.7+) dictionary framework to test function, which places the framework
we can guess from the repr first, then Numpy, then the others.
"""
framework_to_test = {
"pt": is_torch_tensor,
"tf": is_tf_tensor,
"jax": is_jax_tensor,
"np": is_numpy_array,
}
preferred_framework = infer_framework_from_repr(x)
# We will test this one first, then numpy, then the others.
frameworks = [] if preferred_framework is None else [preferred_framework]
if preferred_framework != "np":
frameworks.append("np")
frameworks.extend([f for f in framework_to_test if f not in [preferred_framework, "np"]])
return {f: framework_to_test[f] for f in frameworks}


def is_tensor(x):
"""
Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray`.
Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray` in the order
defined by `infer_framework_from_repr`
"""
if is_torch_fx_proxy(x):
return True
if is_torch_available():
import torch

if isinstance(x, torch.Tensor):
# This gives us a smart order to test the frameworks with the corresponding tests.
framework_to_test_func = _get_frameworks_and_test_func(x)
for test_func in framework_to_test_func.values():
if test_func(x):
return True
if is_tf_available():
import tensorflow as tf

if isinstance(x, tf.Tensor):
return True
# Tracers
if is_torch_fx_proxy(x):
return True

if is_flax_available():
import jax.numpy as jnp
from jax.core import Tracer

if isinstance(x, (jnp.ndarray, Tracer)):
if isinstance(x, Tracer):
return True

return isinstance(x, np.ndarray)
return False


def _is_numpy(x):
Expand Down Expand Up @@ -200,17 +233,27 @@ def to_py_obj(obj):
"""
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
"""

framework_to_py_obj = {
"pt": lambda obj: obj.detach().cpu().tolist(),
"tf": lambda obj: obj.numpy().tolist(),
"jax": lambda obj: np.asarray(obj).tolist(),
"np": lambda obj: obj.tolist(),
}

if isinstance(obj, (dict, UserDict)):
return {k: to_py_obj(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [to_py_obj(o) for o in obj]
elif is_tf_tensor(obj):
return obj.numpy().tolist()
elif is_torch_tensor(obj):
return obj.detach().cpu().tolist()
elif is_jax_tensor(obj):
return np.asarray(obj).tolist()
elif isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays

# This gives us a smart order to test the frameworks with the corresponding tests.
framework_to_test_func = _get_frameworks_and_test_func(obj)
for framework, test_func in framework_to_test_func.items():
if test_func(obj):
return framework_to_py_obj[framework](obj)

# tolist also works on 0d np arrays
if isinstance(obj, np.number):
return obj.tolist()
else:
return obj
Expand All @@ -220,18 +263,26 @@ def to_numpy(obj):
"""
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a Numpy array.
"""

framework_to_numpy = {
"pt": lambda obj: obj.detach().cpu().numpy(),
"tf": lambda obj: obj.numpy(),
"jax": lambda obj: np.asarray(obj),
"np": lambda obj: obj,
}

if isinstance(obj, (dict, UserDict)):
return {k: to_numpy(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return np.array(obj)
elif is_tf_tensor(obj):
return obj.numpy()
elif is_torch_tensor(obj):
return obj.detach().cpu().numpy()
elif is_jax_tensor(obj):
return np.asarray(obj)
else:
return obj

# This gives us a smart order to test the frameworks with the corresponding tests.
framework_to_test_func = _get_frameworks_and_test_func(obj)
for framework, test_func in framework_to_test_func.items():
if test_func(obj):
return framework_to_numpy[framework](obj)

return obj


class ModelOutput(OrderedDict):
Expand Down

0 comments on commit 87808fb

Please sign in to comment.