Commit: amend
vmoens committed Nov 14, 2023
1 parent 0528111 commit 7413381
Showing 8 changed files with 27 additions and 37 deletions.
11 changes: 2 additions & 9 deletions benchmarks/nn/functional_benchmarks_test.py
@@ -9,15 +9,8 @@
 from tensordict import TensorDict
 from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential
 from tensordict.nn.functional_modules import make_functional
-from torch import nn
-
-try:
-    from torch import vmap
-except ImportError:
-    try:
-        from functorch import vmap
-    except ImportError:
-        raise RuntimeError("vmap couldn't be found, check pytorch version.")
+
+from torch import nn, vmap
 
 
 def make_net():
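The hunk above drops the functorch fallback and imports vmap straight from the top-level torch namespace. A minimal sanity check of the new import path, assuming PyTorch >= 2.0 (where vmap is exposed at the top level):

    import torch
    from torch import vmap

    # vmap maps a function over a leading batch dimension without an explicit loop.
    xs = torch.randn(3, 4)
    assert torch.allclose(vmap(torch.sum)(xs), xs.sum(dim=1))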
21 changes: 9 additions & 12 deletions docs/source/reference/tensordict.rst
@@ -20,35 +20,32 @@ regular pytorch tensors.
 Memory-mapped tensors
 ---------------------
 
-:obj:`tensordict` offers the :obj:`MemmapTensor` primitive which allows you to work
-with tensors stored in physical memory in a handy way. The main advantages of :obj:`MemmapTensor`
+:obj:`tensordict` offers the :class:`~tensordict.MemoryMappedTensor` primitive which allows you to work
+with tensors stored in physical memory in a handy way. The main advantages of :class:`~tensordict.MemoryMappedTensor`
 are its ease of construction (no need to handle the storage of a tensor), the possibility to
 work with big contiguous data that would not fit in memory, efficient (de)serialization across processes, and
 efficient indexing of stored tensors.
 
-If all workers have access to the same storage, passing a :obj:`MemmapTensor` will just consist in passing
-a reference to a file on disk plus a bunch of extra meta-data for reconstructing it when
-sent across processes or workers on a same machine (both in multiprocess and distributed settings).
-The same goes with indexed memory-mapped tensors.
+If all workers have access to the same storage, passing a :class:`~tensordict.MemoryMappedTensor`
+amounts to passing a reference to a file on disk plus some extra metadata used
+to reconstruct it when sent across processes or workers on the same machine
+(both in multiprocess and distributed settings). The same goes for indexed
+memory-mapped tensors.
 
 Indexing memory-mapped tensors is much faster than loading several independent files from
 disk and does not require loading the full content of the array into memory.
 Their physical storage is otherwise no different from that of any other PyTorch tensor:
 
 .. code-block:: Python
 
-    >>> my_images = MemmapTensor(1_000_000, 3, 480, 480, dtype=torch.uint8)
+    >>> my_images = MemoryMappedTensor.empty((1_000_000, 3, 480, 480), dtype=torch.uint8)
     >>> mini_batch = my_images[:10]  # just reads the first 10 images of the dataset
     >>> mini_batch = my_images.as_tensor()[:10]  # similar but using pytorch tensors directly
 
-The main difference between the two examples above is that, in the first case, indexing
-returns a :obj:`MemmapTensor` instance, whereas in the second a :obj:`torch.Tensor` is returned.
-
 .. autosummary::
     :toctree: generated/
     :template: td_template.rst
 
-    MemmapTensor
+    MemoryMappedTensor
 
 Utils
 -----
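To illustrate the cross-process paragraph above: a minimal sketch (not part of the commit) of handing a MemoryMappedTensor to a worker process. It assumes both processes can reach the backing file, so that only the file reference and metadata cross the process boundary, not the data:

    import torch
    from torch import multiprocessing as mp
    from tensordict import MemoryMappedTensor

    def read_slice(images):
        # The child receives a reference to the on-disk storage plus metadata;
        # the underlying data is read lazily from disk, not copied over the pipe.
        print(images[:10].mean())

    if __name__ == "__main__":
        images = MemoryMappedTensor.empty((1_000, 3, 32, 32), dtype=torch.float32)
        images.copy_(torch.rand(1_000, 3, 32, 32))
        proc = mp.Process(target=read_slice, args=(images,))
        proc.start()
        proc.join()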
2 changes: 1 addition & 1 deletion setup.cfg
@@ -13,7 +13,7 @@ per-file-ignores =
     ./hubconf.py: F401
     test/smoke_test.py: F401
     test/smoke_test_deps.py: F401
-    test_*.py: E731, E266
+    test_*.py: E731, E266, TOR101
 exclude = venv
 extend-select = B901, C401, C408, C409
8 changes: 4 additions & 4 deletions tutorials/sphinx_tuto/data_fashion.py
@@ -101,13 +101,13 @@
 
 batch_size = 64
 
-train_dataloader = DataLoader(training_data, batch_size=batch_size)
-test_dataloader = DataLoader(test_data, batch_size=batch_size)
+train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
+test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401
 
-train_dataloader_td = DataLoader(
+train_dataloader_td = DataLoader(  # noqa: TOR401
     training_data_td, batch_size=batch_size, collate_fn=lambda x: x
 )
-test_dataloader_td = DataLoader(
+test_dataloader_td = DataLoader(  # noqa: TOR401
     test_data_td, batch_size=batch_size, collate_fn=lambda x: x
 )
8 changes: 4 additions & 4 deletions tutorials/sphinx_tuto/tensorclass_fashion.py
@@ -103,13 +103,13 @@ def from_dataset(cls, dataset, device=None):
 
 batch_size = 64
 
-train_dataloader = DataLoader(training_data, batch_size=batch_size)
-test_dataloader = DataLoader(test_data, batch_size=batch_size)
+train_dataloader = DataLoader(training_data, batch_size=batch_size)  # noqa: TOR401
+test_dataloader = DataLoader(test_data, batch_size=batch_size)  # noqa: TOR401
 
-train_dataloader_tc = DataLoader(
+train_dataloader_tc = DataLoader(  # noqa: TOR401
     training_data_tc, batch_size=batch_size, collate_fn=lambda x: x
 )
-test_dataloader_tc = DataLoader(
+test_dataloader_tc = DataLoader(  # noqa: TOR401
     test_data_tc, batch_size=batch_size, collate_fn=lambda x: x
 )
4 changes: 2 additions & 2 deletions tutorials/sphinx_tuto/tensorclass_imagenet.py
@@ -308,12 +308,12 @@ def __call__(self, x: ImageNetData):
     num_workers=NUM_WORKERS,
 )
 
-train_dataloader_tc = DataLoader(
+train_dataloader_tc = DataLoader(  # noqa: TOR401
     train_data_tc,
     batch_size=batch_size,
     collate_fn=Collate(collate_transform, device),
 )
-val_dataloader_tc = DataLoader(
+val_dataloader_tc = DataLoader(  # noqa: TOR401
     val_data_tc,
     batch_size=batch_size,
     collate_fn=Collate(device=device),
4 changes: 2 additions & 2 deletions tutorials/sphinx_tuto/tensordict_module_functional.py
@@ -58,7 +58,7 @@ def forward(self, x):
 
 params_expand = [p.expand(3, *p.shape) for p in params]
 buffers_expand = [p.expand(3, *p.shape) for p in buffers]
-print(functorch.vmap(func, (0, 0, None))(params_expand, buffers_expand, tensordict))
+print(torch.vmap(func, (0, 0, None))(params_expand, buffers_expand, tensordict))
 
 ###############################################################################
 # We can also use the native :func:`make_functional <tensordict.nn.make_functional>`
@@ -74,5 +74,5 @@ def forward(self, x):
 params = make_functional(model)
 # we stack two groups of parameters to show the vmap usage:
 params = torch.stack([params, params.apply(lambda x: torch.zeros_like(x))], 0)
-result_td = functorch.vmap(model, (None, 0))(tensordict, params)
+result_td = torch.vmap(model, (None, 0))(tensordict, params)
 print("the output tensordict shape is: ", result_td.shape)
6 changes: 3 additions & 3 deletions tutorials/sphinx_tuto/tensordict_shapes.py
@@ -67,7 +67,7 @@
 chunks = tensordict.split([3, 1], dim=1)
 assert chunks[0].batch_size == torch.Size([3, 3])
 assert chunks[1].batch_size == torch.Size([3, 1])
-torch.testing.assert_allclose(chunks[0]["a"], tensordict["a"][:, :-1])
+torch.testing.assert_close(chunks[0]["a"], tensordict["a"][:, :-1])
 
 ##############################################################################
 # .. note::
@@ -108,7 +108,7 @@
 slices = tensordict.unbind(dim=1)
 assert len(slices) == 4
 assert all(s.batch_size == torch.Size([3]) for s in slices)
-torch.testing.assert_allclose(slices[0]["a"], tensordict["a"][:, 0])
+torch.testing.assert_close(slices[0]["a"], tensordict["a"][:, 0])
 
 ##############################################################################
 # Stacking and concatenating
@@ -181,7 +181,7 @@
 
 exp_tensordict = tensordict.expand(2, *tensordict.batch_size)
 assert exp_tensordict.batch_size == torch.Size([2, 3, 4])
-torch.testing.assert_allclose(exp_tensordict["a"][0], exp_tensordict["a"][1])
+torch.testing.assert_close(exp_tensordict["a"][0], exp_tensordict["a"][1])
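The assert_allclose to assert_close substitutions above track the deprecation of torch.testing.assert_allclose in recent PyTorch releases. assert_close derives its default tolerances from the dtype and accepts explicit keyword-only overrides; a short sketch:

    import torch

    # Passes under the default float32 tolerances.
    torch.testing.assert_close(torch.tensor(1.0), torch.tensor(1.0 + 1e-9))
    # rtol and atol must be overridden together, as keyword arguments.
    torch.testing.assert_close(torch.tensor(1.00), torch.tensor(1.01), rtol=0.1, atol=0.0)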
