Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature,Refactor] More args in constructors, refactor free functions #1116

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/reference/tensordict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@ or ``cat``.

cat
from_consolidated
from_any
from_dict
from_h5
from_module
from_modules
from_namedtuple
from_pytree
from_struct_array
from_tuple
fromkeys
is_batchedtensor
lazy_stack
Expand Down
10 changes: 6 additions & 4 deletions tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@
from tensordict._td import (
cat,
from_consolidated,
from_dict,
from_h5,
from_module,
from_modules,
from_namedtuple,
from_pytree,
from_struct_array,
fromkeys,
is_tensor_collection,
lazy_stack,
Expand All @@ -29,6 +25,12 @@
)

from tensordict.base import (
from_any,
from_dict,
from_h5,
from_namedtuple,
from_struct_array,
from_tuple,
get_defaults_to_none,
set_get_defaults_to_none,
TensorDictBase,
Expand Down
151 changes: 2 additions & 149 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,6 +2050,8 @@ def from_dict(
input_dict[key] = TensorDict.from_any(
value,
auto_batch_size=False,
device=device,
batch_size=batch_size,
)
# regular __init__ breaks because a tensor may have the same batch-size as the tensordict
out = cls(
Expand Down Expand Up @@ -4863,139 +4865,6 @@ def from_modules(
)


def from_dict(
input_dict, *others, batch_size=None, device=None, batch_dims=None, names=None
):
"""Returns a TensorDict created from a dictionary or another :class:`~.tensordict.TensorDict`.

If ``batch_size`` is not specified, returns the maximum batch size possible.

This function works on nested dictionaries too, or can be used to determine the
batch-size of a nested tensordict.

Args:
input_dict (dictionary, optional): a dictionary to use as a data source
(nested keys compatible).
batch_size (iterable of int, optional): a batch size for the tensordict.
device (torch.device or compatible type, optional): a device for the TensorDict.
batch_dims (int, optional): the ``batch_dims`` (ie number of leading dimensions
to be considered for ``batch_size``). Exclusinve with ``batch_size``.
Note that this is the __maximum__ number of batch dims of the tensordict,
a smaller number is tolerated.
names (list of str, optional): the dimension names of the tensordict.

Examples:
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
>>> print(from_dict(input_dict))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> # nested dict: the nested TensorDict can have a different batch-size
>>> # as long as its leading dims match.
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
>>> print(from_dict(input_dict))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> # we can also use this to work out the batch sie of a tensordict
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
>>> print(
from_dict(input_td))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 4]),
device=None,
is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)

"""
return TensorDict.from_dict(
input_dict,
*others,
batch_size=batch_size,
device=device,
batch_dims=batch_dims,
names=names,
)


def from_namedtuple(named_tuple, *, auto_batch_size: bool = False):
"""Converts a namedtuple to a TensorDict recursively.

Keyword Args:
auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically.
Defaults to ``False``.

Examples:
>>> from tensordict import TensorDict, from_namedtuple
>>> import torch
>>> data = TensorDict({
... "a_tensor": torch.zeros((3)),
... "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3])
>>> nt = data.to_namedtuple()
>>> print(nt)
GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!'))
>>> from_namedtuple(nt, auto_batch_size=True)
TensorDict(
fields={
a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
nested: TensorDict(
fields={
a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None),
a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)

"""
return TensorDict.from_namedtuple(named_tuple, auto_batch_size=auto_batch_size)


def from_struct_array(struct_array: np.ndarray, device: torch.device | None = None):
"""Converts a structured numpy array to a TensorDict.

The content of the resulting TensorDict will share the same memory content as the numpy array (it is a zero-copy
operation). Changing values of the structured numpy array in-place will affect the content of the TensorDict.

Examples:
>>> x = np.array(
... [("Rex", 9, 81.0), ("Fido", 3, 27.0)],
... dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")],
... )
>>> td = from_struct_array(x)
>>> x_recon = td.to_struct_array()
>>> assert (x_recon == x).all()
>>> assert x_recon.shape == x.shape
>>> # Try modifying x age field and check effect on td
>>> x["age"] += 1
>>> assert (td["age"] == np.array([10, 4])).all()

"""
return TensorDict.from_struct_array(struct_array, device=device)


def from_pytree(
pytree,
*,
Expand Down Expand Up @@ -5060,22 +4929,6 @@ def from_pytree(
)


def from_h5(
filename,
mode="r",
):
"""Creates a PersistentTensorDict from a h5 file.

This function will automatically determine the batch-size for each nested
tensordict.

Args:
filename (str): the path to the h5 file.
mode (str, optional): reading mode. Defaults to ``"r"``.
"""
return TensorDict.from_h5(filename, mode="r")


def stack(input, dim=0, *, out=None):
"""Stacks tensordicts into a single tensordict along the given dimension.

Expand Down
Loading
Loading