[Doc] Update readme for v0.3 (#635)
Co-authored-by: Shagun Sodhani <[email protected]>
vmoens and shagunsodhani authored Jan 31, 2024
1 parent fdcc403 commit 2c73daa
Showing 2 changed files with 124 additions and 31 deletions.
153 changes: 122 additions & 31 deletions README.md

[**Installation**](#installation) | [**General features**](#general) |
[**Tensor-like features**](#tensor-like-features) | [**Distributed capabilities**](#distributed-capabilities) |
[**TensorDict for functional programming**](#tensordict-for-functional-programming) |
[**TensorDict for parameter serialization**](#tensordict-for-parameter-serialization) |
[**Lazy preallocation**](#lazy-preallocation) | [**Nesting TensorDicts**](#nesting-tensordicts) | [**TensorClass**](#tensorclass)

`TensorDict` is a dictionary-like class that inherits properties from tensors,
such as indexing, shape operations, casting to device or point-to-point communication
in distributed settings.

The main purpose of TensorDict is to make code-bases more _readable_ and _modular_ by abstracting away tailored operations:
```python
for i, data in enumerate(dataset):
    # the model reads and writes tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

### General

A tensordict is primarily defined by its `batch_size` (or `shape`) and its key-value pairs:
```python
>>> from tensordict import TensorDict
>>> import torch
>>> data = TensorDict({
... "key 1": torch.ones(3, 4, 5),
... "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
```

The `batch_size` and the first dimensions of each of the tensors must be compliant.
The tensors can be of any dtype and device. Optionally, a tensordict can be restricted to
live on a dedicated device, in which case every tensor written to it is sent to that device:
```python
>>> data = TensorDict({
... "key 1": torch.ones(3, 4, 5),
... "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4], device="cuda:0")
>>> tensordict["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert tensordict["key 3"].device is torch.device("cuda:0")
>>> data["key 3"] = torch.randn(3, 4, device="cpu")
>>> assert data["key 3"].device is torch.device("cuda:0")
```

But that is not all: you can also store nested values in a tensordict:
```python
>>> data["nested", "key"] = torch.zeros(3, 4) # the batch-size must match
```
and any nested tuple structure will be unravelled to make it easy to read code and
write ops programmatically:
```python
>>> data["nested", ("supernested", ("key",))] = torch.zeros(3, 4) # the batch-size must match
>>> assert (data["nested", "supernested", "key"] == 0).all()
>>> assert (("nested",), "supernested", (("key",),)) in data.keys(include_nested=True) # this works too!
```

You can also store non-tensor data in tensordicts:

```python
>>> data = TensorDict({"a-tensor": torch.randn(1, 2)}, batch_size=[1, 2])
>>> data["non-tensor"] = "a string!"
>>> assert data["non-tensor"] == "a string!"
```

### Tensor-like features

TensorDict objects can be indexed exactly like tensors. The result of indexing
a TensorDict is another TensorDict whose tensors are indexed along the required dimension:
```python
>>> data = TensorDict({
... "key 1": torch.ones(3, 4, 5),
... "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> sub_tensordict = data[..., :2]
>>> assert sub_tensordict.shape == torch.Size([3, 2])
>>> assert sub_tensordict["key 1"].shape == torch.Size([3, 2, 5])
```
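Indexing can also be used to write into a tensordict. The following is a minimal sketch (not taken from the original README), assuming that assigning a tensordict to an indexed view updates the stored tensors in place:
```python
>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
... }, batch_size=[3, 4])
>>> # the assigned tensordict must match the shape of the indexed view
>>> data[0] = TensorDict({"key 1": torch.zeros(4, 5)}, batch_size=[4])
>>> assert (data["key 1"][0] == 0).all()
```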
Similarly, one can build tensordicts by stacking or concatenating single tensordicts.
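A minimal sketch (not from the original README): `torch.stack` adds a new leading batch dimension, while `torch.cat` concatenates along an existing one, provided keys and trailing shapes match.
```python
>>> td1 = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3])
>>> td2 = TensorDict({"a": torch.ones(3, 4)}, batch_size=[3])
>>> stacked = torch.stack([td1, td2], dim=0)
>>> concatenated = torch.cat([td1, td2], dim=0)
>>> assert stacked.batch_size == torch.Size([2, 3])
>>> assert concatenated.batch_size == torch.Size([6])
```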

TensorDict instances can also be reshaped, viewed, squeezed and unsqueezed:
```python
>>> data = TensorDict({
...     "key 1": torch.ones(3, 4, 5),
...     "key 2": torch.zeros(3, 4, 5, dtype=torch.bool),
... }, batch_size=[3, 4])
>>> print(data.view(-1).batch_size)
torch.Size([12])
>>> print(data.reshape(-1).batch_size)
torch.Size([12])
>>> print(data.unsqueeze(-1).batch_size)
torch.Size([3, 4, 1])
```

One can also clone them, update them in-place or not, split them, unbind them, expand them, and much more.
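A minimal sketch of a few of these operations (the shapes and values below are illustrative):
```python
>>> data = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3])
>>> cloned = data.clone()              # deep copy of every entry
>>> data.update({"b": torch.ones(3)})  # in-place update of the mapping
>>> chunks = data.split([1, 2], dim=0) # tuple of tensordicts with batch sizes [1] and [2]
>>> rows = data.unbind(0)              # tuple of 3 tensordicts with empty batch size
>>> bigger = data.expand(2, 3)         # batch_size == [2, 3]
```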

If a functionality is missing, it is easy to call it using `apply()` or `apply_()`:
```python
tensordict_uniform = data.apply(lambda tensor: tensor.uniform_())
```

``apply()`` can also be great to filter a tensordict, for instance:
```python
data = TensorDict({"a": torch.tensor(1.0, dtype=torch.float), "b": torch.tensor(1, dtype=torch.int64)}, [])
data_float = data.apply(lambda x: x if x.dtype == torch.float else None) # contains only the "a" key
assert "b" not in data_float
```

### Distributed capabilities

Complex data structures can be cumbersome to synchronize in distributed settings.
When nodes share a common scratch space, the memory-mapped tensors described in the
serialization section below can be used to seamlessly send, receive and read a huge amount of data.
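As an illustration only (this is not part of the README), a minimal sketch of point-to-point communication, assuming a `torch.distributed` process group has already been initialized and that `TensorDict.send`/`TensorDict.recv` mirror their tensor counterparts:
```python
>>> import torch.distributed as dist
>>> data = TensorDict({"weights": torch.randn(10, 3), "step": torch.zeros(())}, batch_size=[])
>>> if dist.get_rank() == 0:
...     data.send(1)   # send every tensor in the tensordict to rank 1
... else:
...     data.recv(0)   # fill the tensordict in place with tensors received from rank 0
```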

### TensorDict for functional programming

We also provide an API to use TensorDict in conjunction with [FuncTorch](https://pytorch.org/functorch).
For instance, TensorDict makes it easy to concatenate model weights to do model ensembling:
```python
>>> from torch import nn
>>> from tensordict import TensorDict
>>> import torch
>>> from torch import vmap
>>> layer1 = nn.Linear(3, 4)
>>> layer2 = nn.Linear(4, 4)
>>> model = nn.Sequential(layer1, layer2)
>>> params = TensorDict.from_module(model)
>>> # Let's use our functional module
>>> x = torch.randn(10, 3)
>>> with params.to_module(model):
...     out = model(x)
>>> # an ensemble of models: we stack params along the first dimension...
>>> params_stack = torch.stack([params, params], 0)
>>> # ... and use it as an input we'd like to pass through the model
>>> def func(x, params):
...     with params.to_module(model):
...         return model(x)
>>> y = vmap(func, (None, 0))(x, params_stack)
>>> print(y.shape)
torch.Size([2, 10, 4])
```

Moreover, tensordict modules are compatible with `torch.fx` and (soon) `torch.compile`,
which means that you can get the best of both worlds: a codebase that is
both readable and future-proof as well as efficient and portable!
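
For reference, a tensordict module reads its inputs from and writes its outputs to a tensordict. A minimal sketch using `tensordict.nn.TensorDictModule` (the key names are arbitrary):
```python
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> encoder = TensorDictModule(nn.Linear(3, 4), in_keys=["observation"], out_keys=["hidden"])
>>> head = TensorDictModule(nn.Linear(4, 2), in_keys=["hidden"], out_keys=["action"])
>>> policy = TensorDictSequential(encoder, head)
>>> data = TensorDict({"observation": torch.randn(8, 3)}, batch_size=[8])
>>> data = policy(data)  # writes "hidden" and "action" into the tensordict
>>> assert data["action"].shape == torch.Size([8, 2])
```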

### TensorDict for parameter serialization and building datasets

TensorDict offers an API for parameter serialization that can be >3x faster than
regular calls to `torch.save(state_dict)`. Moreover, because tensors will be saved
independently on disk, you can deserialize your checkpoint on an arbitrary slice
of the model.

```python
>>> model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 3))
>>> params = TensorDict.from_module(model)
>>> params.memmap("/path/to/saved/folder/", num_threads=16) # adjust num_threads for speed
>>> # load params
>>> params = TensorDict.load_memmap("/path/to/saved/folder/", num_threads=16)
>>> params.to_module(model) # load onto model
>>> params["0"].to_module(model[0]) # load on a slice of the model
>>> # in the latter case we could also have loaded only the slice we needed
>>> params0 = TensorDict.load_memmap("/path/to/saved/folder/0", num_threads=16)
>>> params0.to_module(model[0]) # load on a slice of the model
```

The same functionality can be used to access data in a dataset stored on disk.
Storing the dataset as a single contiguous tensor on disk, accessed through the
`tensordict.MemoryMappedTensor` primitive, and reading slices of it is not only
**much** faster than loading single files one at a time, but it is also easier and safer
(because there is no pickling or third-party library involved):

```python
# allocate memory of the dataset on disk
data = TensorDict({
"images": torch.zeros((128, 128, 3), dtype=torch.uint8),
"labels": torch.zeros((), dtype=torch.int)}, batch_size=[])
data = data.expand(1000000)
data = data.memmap_like("/path/to/dataset")
# ==> Fill your dataset here
# Let's get 3 items of our dataset:
data[torch.tensor([1, 10000, 500000])] # This is much faster than loading the 3 images independently
```
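
The dataset can then be mapped back from disk in another process without copying it to RAM; a short sketch reusing the (hypothetical) path from the block above:
```python
# map the dataset back from disk
data = TensorDict.load_memmap("/path/to/dataset")
# only the requested rows are read from disk
batch = data[torch.tensor([1, 10000, 500000])]
```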

### Preprocessing with TensorDict.map

Preprocessing huge contiguous (or not!) datasets can be done via `TensorDict.map`
which will dispatch a task to various workers:

```python
import torch
from tensordict import TensorDict, MemoryMappedTensor
import tempfile

def process_data(data):
    images = data.get("images").flip(-2).clone()
    labels = data.get("labels") // 10
    # we update the td inplace
    data.set_("images", images)  # flip image
    data.set_("labels", labels)  # cluster labels

if __name__ == "__main__":
    # build or load `data` here, e.g. the memory-mapped dataset created above
    data_preproc = data.map(process_data, num_workers=4, chunksize=0, pbar=True)  # process one image at a time
```

### Lazy preallocation

Pre-allocating tensors can be cumbersome when the number of
items varies according to the script configuration. TensorDict solves this in an elegant way.
Assume you are working with a function `foo() -> TensorDict`, e.g.
```python
def foo():
    data = TensorDict({}, batch_size=[])
    data["a"] = torch.randn(3)
    data["b"] = TensorDict({"c": torch.zeros(2)}, batch_size=[])
    return data
```
and you would like to call this function repeatedly. You could do this in two ways.
The first would simply be to stack the calls to the function:
```python
data = torch.stack([foo() for _ in range(N)])
```
However, you could also choose to preallocate the tensordict:
```python
data = TensorDict({}, batch_size=[N])
for i in range(N):
    data[i] = foo()
```
which also results in a tensordict with batch size `[10]` (when `N = 10`).
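As a quick sketch of what the pre-allocation gives (assuming `N = 10` and the `foo()` defined above), every entry now carries the extra leading batch dimension:
```python
assert data.batch_size == torch.Size([10])
assert data["a"].shape == torch.Size([10, 3])
assert data["b", "c"].shape == torch.Size([10, 2])
```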

### Nesting TensorDicts

It is possible to nest tensordicts. The only requirement is that the sub-tensordict's batch size
must include the parent tensordict's batch size as its leading dimensions.
We can switch easily between hierarchical and flat representations.
For instance, the following code will result in a single-level tensordict with keys `"key 1"` and `"key 2.sub-key"`:
```python
>>> data = TensorDict({
... "key 1": torch.ones(3, 4, 5),
... "key 2": TensorDict({"sub-key": torch.randn(3, 4, 5, 6)}, batch_size=[3, 4, 5])
... }, batch_size=[3, 4])
>>> tensordict_flatten = data.flatten_keys(separator=".")
```
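The reverse operation is also available; a minimal sketch, assuming the flattened tensordict from above:
```python
>>> data_nested = tensordict_flatten.unflatten_keys(separator=".")
>>> assert (data_nested == data).all()
```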

Accessing nested tensordicts can be achieved with a single index:
```python
>>> sub_value = data["key 2", "sub-key"]
```

## TensorClass
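
As a minimal sketch of the idea (the field names below are illustrative), the `@tensorclass` decorator builds a dataclass-like container with the same batch semantics as `TensorDict`:
```python
>>> import torch
>>> from tensordict import tensorclass
>>> @tensorclass
... class MyData:
...     images: torch.Tensor
...     labels: torch.Tensor
>>> data = MyData(images=torch.zeros(8, 3, 64, 64), labels=torch.zeros(8, dtype=torch.long), batch_size=[8])
>>> first = data[0]  # indexing works like a TensorDict
>>> assert first.images.shape == torch.Size([3, 64, 64])
```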
2 changes: 2 additions & 0 deletions tutorials/README.md
README Tutos
============

Check a rendered version of the tutorials in the tensordict documentation: https://pytorch.org/tensordict
