Skip to content

Commit

Permalink
[Feature,Refactor] More args in constructors, refactor free functions
Browse files Browse the repository at this point in the history
ghstack-source-id: 35e2444bb5d4bf92b78437063e2f5aec83651713
Pull Request resolved: #1116
  • Loading branch information
vmoens committed Nov 28, 2024
1 parent 004f979 commit 91737d0
Show file tree
Hide file tree
Showing 6 changed files with 506 additions and 209 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/tensordict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ or ``cat``.

cat
from_consolidated
from_any
from_dict
from_h5
from_module
from_modules
from_namedtuple
from_pytree
from_struct_array
from_tuple
fromkeys
is_batchedtensor
lazy_stack
Expand Down
10 changes: 6 additions & 4 deletions tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@
from tensordict._td import (
cat,
from_consolidated,
from_dict,
from_h5,
from_module,
from_modules,
from_namedtuple,
from_pytree,
from_struct_array,
fromkeys,
is_tensor_collection,
lazy_stack,
Expand All @@ -29,6 +25,12 @@
)

from tensordict.base import (
from_any,
from_dict,
from_h5,
from_namedtuple,
from_struct_array,
from_tuple,
get_defaults_to_none,
set_get_defaults_to_none,
TensorDictBase,
Expand Down
151 changes: 2 additions & 149 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,6 +2050,8 @@ def from_dict(
input_dict[key] = TensorDict.from_any(
value,
auto_batch_size=False,
device=device,
batch_size=batch_size,
)
# regular __init__ breaks because a tensor may have the same batch-size as the tensordict
out = cls(
Expand Down Expand Up @@ -4863,139 +4865,6 @@ def from_modules(
)


def from_dict(
input_dict, *others, batch_size=None, device=None, batch_dims=None, names=None
):
"""Returns a TensorDict created from a dictionary or another :class:`~.tensordict.TensorDict`.
If ``batch_size`` is not specified, returns the maximum batch size possible.
This function works on nested dictionaries too, or can be used to determine the
batch-size of a nested tensordict.
Args:
input_dict (dictionary, optional): a dictionary to use as a data source
(nested keys compatible).
batch_size (iterable of int, optional): a batch size for the tensordict.
device (torch.device or compatible type, optional): a device for the TensorDict.
batch_dims (int, optional): the ``batch_dims`` (ie number of leading dimensions
to be considered for ``batch_size``). Exclusinve with ``batch_size``.
Note that this is the __maximum__ number of batch dims of the tensordict,
a smaller number is tolerated.
names (list of str, optional): the dimension names of the tensordict.
Examples:
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
>>> print(from_dict(input_dict))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> # nested dict: the nested TensorDict can have a different batch-size
>>> # as long as its leading dims match.
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
>>> print(from_dict(input_dict))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> # we can also use this to work out the batch sie of a tensordict
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
>>> print(
from_dict(input_td))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
"""
return TensorDict.from_dict(
input_dict,
*others,
batch_size=batch_size,
device=device,
batch_dims=batch_dims,
names=names,
)


def from_namedtuple(named_tuple, *, auto_batch_size: bool = False):
"""Converts a namedtuple to a TensorDict recursively.
Keyword Args:
auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically.
Defaults to ``False``.
Examples:
>>> from tensordict import TensorDict, from_namedtuple
>>> import torch
>>> data = TensorDict({
... "a_tensor": torch.zeros((3)),
... "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3])
>>> nt = data.to_namedtuple()
>>> print(nt)
GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!'))
>>> from_namedtuple(nt, auto_batch_size=True)
TensorDict(
fields={
a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None),
a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
"""
return TensorDict.from_namedtuple(named_tuple, auto_batch_size=auto_batch_size)


def from_struct_array(struct_array: np.ndarray, device: torch.device | None = None):
"""Converts a structured numpy array to a TensorDict.
The content of the resulting TensorDict will share the same memory content as the numpy array (it is a zero-copy
operation). Changing values of the structured numpy array in-place will affect the content of the TensorDict.
Examples:
>>> x = np.array(
... [("Rex", 9, 81.0), ("Fido", 3, 27.0)],
... dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")],
... )
>>> td = from_struct_array(x)
>>> x_recon = td.to_struct_array()
>>> assert (x_recon == x).all()
>>> assert x_recon.shape == x.shape
>>> # Try modifying x age field and check effect on td
>>> x["age"] += 1
>>> assert (td["age"] == np.array([10, 4])).all()
"""
return TensorDict.from_struct_array(struct_array, device=device)


def from_pytree(
pytree,
*,
Expand Down Expand Up @@ -5060,22 +4929,6 @@ def from_pytree(
)


def from_h5(
filename,
mode="r",
):
"""Creates a PersistentTensorDict from a h5 file.
This function will automatically determine the batch-size for each nested
tensordict.
Args:
filename (str): the path to the h5 file.
mode (str, optional): reading mode. Defaults to ``"r"``.
"""
return TensorDict.from_h5(filename, mode="r")


def stack(input, dim=0, *, out=None):
"""Stacks tensordicts into a single tensordict along the given dimension.
Expand Down
Loading

0 comments on commit 91737d0

Please sign in to comment.