[Performance] Faster split, chunk and unbind (#563)
vmoens authored Nov 21, 2023 · 1 parent a5bed01 · commit 3689afa
Showing 3 changed files with 97 additions and 44 deletions.
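For context before the diff: the three methods this commit touches mirror their torch counterparts but apply to every entry of a tensordict at once. A minimal usage sketch against the public tensordict API (shapes chosen purely for illustration):

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(4, 5), "b": torch.ones(4, 3)}, batch_size=[4])

# split: like torch.Tensor.split, yields tensordicts covering the batch dim
parts = td.split(2, dim=0)
assert [p.batch_size for p in parts] == [torch.Size([2]), torch.Size([2])]

# chunk: like torch.chunk; after this commit it falls back on split
front, back = td.chunk(2, dim=0)

# unbind: like torch.unbind, drops the indexed batch dimension
rows = td.unbind(0)
assert rows[0].batch_size == torch.Size([]) and rows[0]["a"].shape == (5,)
```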
111 changes: 86 additions & 25 deletions tensordict/_td.py
@@ -15,7 +15,7 @@
 from numbers import Number
 from pathlib import Path
 from textwrap import indent
-from typing import Any, Callable, Iterable, Iterator, Sequence
+from typing import Any, Callable, Iterable, Iterator, List, Sequence
 from warnings import warn

 import numpy as np
@@ -585,8 +585,14 @@ def _convert_to_tensordict(self, dict_value: dict[str, Any]) -> T:
             _is_memmap=self._is_memmap,
         )

-    def _index_tensordict(self, index: IndexType) -> T:
+    def _index_tensordict(
+        self,
+        index: IndexType,
+        new_batch_size: torch.Size | None = None,
+        names: List[str] | None = None,
+    ) -> T:
         batch_size = self.batch_size
+        batch_dims = len(batch_size)
         if (
             not batch_size
             and index is not None
@@ -595,10 +601,24 @@ def _index_tensordict(self, index: IndexType) -> T:
             raise RuntimeError(
                 f"indexing a tensordict with td.batch_dims==0 is not permitted. Got index {index}."
             )
-        names = self._get_names_idx(index)
-        batch_size = _getitem_batch_size(batch_size, index)
+        if names is None:
+            names = self._get_names_idx(index)
+        if new_batch_size is not None:
+            batch_size = new_batch_size
+        else:
+            batch_size = _getitem_batch_size(batch_size, index)
+        source = {}
+        for key, item in self.items():
+            if isinstance(item, TensorDict):
+                # this is the simplest case, we can pre-compute the batch size easily
+                new_batch_size = batch_size + item.batch_size[batch_dims:]
+                source[key] = item._index_tensordict(
+                    index, new_batch_size=new_batch_size
+                )
+            else:
+                source[key] = _get_item(item, index)
         return TensorDict(
-            source={key: _get_item(item, index) for key, item in self.items()},
+            source=source,
             batch_size=batch_size,
             device=self.device,
             names=names,
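The gain in `_index_tensordict` above comes from letting callers that already know the post-indexing batch shape pass it in, so it is not re-derived for every nested value. A standalone sketch of that idea, using plain dicts and a hypothetical `getitem_batch_size` helper (not the library's private `_getitem_batch_size`, and simplified to nested dicts that share the root batch size):

```python
import torch

def getitem_batch_size(batch_size: torch.Size, index) -> torch.Size:
    # stand-in for the expensive shape derivation, done via a meta-device view
    return torch.empty(batch_size, device="meta")[index].shape

def index_nested(data: dict, index, batch_size, new_batch_size=None):
    # the resulting shape is computed once, at the root of the recursion ...
    if new_batch_size is None:
        new_batch_size = getitem_batch_size(batch_size, index)
    out = {}
    for key, item in data.items():
        if isinstance(item, dict):
            # ... and reused verbatim at every nested level
            out[key] = index_nested(item, index, batch_size, new_batch_size)
        else:
            out[key] = item[index]
    return out
```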
@@ -650,54 +670,90 @@ def unbind(self, dim: int) -> tuple[T, ...]:
         names = copy(self.names)
         names = [name for i, name in enumerate(names) if i != dim]
         out = []
-        unbind_self_dict = {key: tensor.unbind(dim) for key, tensor in self.items()}
+        # unbind_self_dict = {key: tensor.unbind(dim) for key, tensor in self.items()}
+        prefix = (slice(None),) * dim
         for _idx in range(self.batch_size[dim]):
-            td = TensorDict(
-                {key: tensor[_idx] for key, tensor in unbind_self_dict.items()},
-                batch_size=batch_size,
-                _run_checks=False,
-                device=self.device,
-                _is_memmap=False,
-                _is_shared=False,
-                names=names,
-            )
+            _idx = prefix + (_idx,)
+            td = self._index_tensordict(_idx, new_batch_size=batch_size, names=names)
+            # td = TensorDict(
+            #     {key: tensor[_idx] for key, tensor in unbind_self_dict.items()},
+            #     batch_size=batch_size,
+            #     _run_checks=False,
+            #     device=self.device,
+            #     _is_memmap=False,
+            #     _is_shared=False,
+            #     names=names,
+            # )
             out.append(td)
             if self.is_shared():
-                out[-1].share_memory_()
+                td._is_shared = True
             elif self.is_memmap():
-                out[-1].memmap_()
+                td._is_memmap = True
         return tuple(out)

     def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
         # we must use slices to keep the storage of the tensors
         WRONG_TYPE = "split(): argument 'split_size' must be int or list of ints"
         batch_size = self.batch_size
+        batch_sizes = []
         batch_dims = len(batch_size)
         if dim < 0:
             dim = len(batch_size) + dim
         if dim >= batch_dims or dim < 0:
             raise IndexError(
                 f"Dimension out of range (expected to be in range of [-{self.batch_dims}, {self.batch_dims - 1}], but got {dim})"
             )
+        max_size = batch_size[dim]
         if isinstance(split_size, int):
             idx0 = 0
-            idx1 = split_size
+            idx1 = min(max_size, split_size)
             split_sizes = [slice(idx0, idx1)]
-            while idx1 < batch_size[dim]:
+            batch_sizes.append(
+                torch.Size(
+                    tuple(
+                        d if i != dim else idx1 - idx0 for i, d in enumerate(batch_size)
+                    )
+                )
+            )
+            while idx1 < max_size:
                 idx0 = idx1
-                idx1 += split_size
+                idx1 = min(max_size, idx1 + split_size)
                 split_sizes.append(slice(idx0, idx1))
+                batch_sizes.append(
+                    torch.Size(
+                        tuple(
+                            d if i != dim else idx1 - idx0
+                            for i, d in enumerate(batch_size)
+                        )
+                    )
+                )
         elif isinstance(split_size, (list, tuple)):
             if len(split_size) == 0:
                 raise RuntimeError("Insufficient number of elements in split_size.")
             try:
                 idx0 = 0
                 idx1 = split_size[0]
                 split_sizes = [slice(idx0, idx1)]
+                batch_sizes.append(
+                    torch.Size(
+                        tuple(
+                            d if i != dim else idx1 - idx0
+                            for i, d in enumerate(batch_size)
+                        )
+                    )
+                )
                 for idx in split_size[1:]:
                     idx0 = idx1
-                    idx1 += idx
+                    idx1 = min(max_size, idx1 + idx)
                     split_sizes.append(slice(idx0, idx1))
+                    batch_sizes.append(
+                        torch.Size(
+                            tuple(
+                                d if i != dim else idx1 - idx0
+                                for i, d in enumerate(batch_size)
+                            )
+                        )
+                    )
             except TypeError:
                 raise TypeError(WRONG_TYPE)
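On the `unbind` change above: rather than pre-unbinding every tensor into per-key tuples, the loop now builds one index per position along the batch dimension, `(slice(None),) * dim + (i,)`, and funnels it through `_index_tensordict`. The two formulations agree in plain torch, which is easy to check:

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)
dim = 1
prefix = (slice(None),) * dim  # t[prefix + (i,)] reads as t[:, i] for dim=1

for i, expected in enumerate(t.unbind(dim)):
    assert torch.equal(t[prefix + (i,)], expected)
```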
@@ -708,7 +764,11 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
         else:
             raise TypeError(WRONG_TYPE)
         index = (slice(None),) * dim
-        return tuple(self[index + (ss,)] for ss in split_sizes)
+        names = self.names
+        return tuple(
+            self._index_tensordict(index + (ss,), new_batch_size=bs, names=names)
+            for ss, bs in zip(split_sizes, batch_sizes)
+        )

     def memmap_like(self, prefix: str | None = None) -> T:
         def save_metadata(data: TensorDictBase, filepath, metadata=None):
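The in-code comment "we must use slices to keep the storage of the tensors" is the key constraint on `split`: indexing with a slice returns a view, so the split pieces share storage with the original rather than copying it. A quick torch-only check of that property (slice bounds chosen to mimic `split_sizes` bookkeeping):

```python
import torch

t = torch.arange(10)
split_sizes = [slice(0, 3), slice(3, 6), slice(6, 10)]
parts = [t[ss] for ss in split_sizes]

# each part is a view over the original storage, not a copy
parts[0][0] = 100
assert t[0].item() == 100
assert parts[0].data_ptr() == t.data_ptr()
```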
@@ -2301,10 +2361,11 @@ def _create_nested_str(self, key):
     # return self.to_tensordict()._apply_nest(*args, **kwargs)
     _convert_to_tensordict = TensorDict._convert_to_tensordict

-    def _get_names_idx(self, *args, **kwargs):
-        raise NotImplementedError
+    _get_names_idx = TensorDict._get_names_idx

-    def _index_tensordict(self, index):
+    def _index_tensordict(self, index, new_batch_size=None, names=None):
+        # we ignore the names and new_batch_size which are only provided for
+        # efficiency purposes
         return self._get_sub_tensordict(index)

     def _remove_batch_dim(self, *args, **kwargs):
24 changes: 10 additions & 14 deletions tensordict/base.py
@@ -17,6 +17,7 @@
     Callable,
     Generator,
     Iterator,
+    List,
     Optional,
     OrderedDict,
     overload,
@@ -526,19 +527,9 @@ def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]:
             raise ValueError(
                 f"chunks must be a strictly positive integer, got {chunks}."
             )
-        indices = []
-        _idx_start = 0
-        if chunks > 1:
-            interval = _idx_end = self.batch_size[dim] // chunks
-        else:
-            interval = _idx_end = self.batch_size[dim]
-        for c in range(chunks):
-            indices.append(slice(_idx_start, _idx_end))
-            _idx_start = _idx_end
-            _idx_end = _idx_end + interval if c < chunks - 2 else self.batch_size[dim]
-        if dim < 0:
-            dim = len(self.batch_size) + dim
-        return tuple(self[(*[slice(None) for _ in range(dim)], idx)] for idx in indices)
+        # fall back on split, using upper rounding
+        split_size = -(self.batch_size[dim] // -chunks)
+        return self.split(split_size, dim=dim)

     @overload
     def unsqueeze(self, dim: int) -> T:
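The `-(a // -b)` expression in the new `chunk` is Python's idiom for ceiling division (floor division of the negation), so the method now just derives the rounded-up split size and defers to the optimized `split`, matching `torch.chunk`'s upper-rounding rule. A small worked check:

```python
import torch

n, chunks = 5, 2
split_size = -(n // -chunks)  # ceil(5 / 2) == 3
assert split_size == 3

# torch.chunk rounds the same way: pieces of size 3 and 2
sizes = [c.shape[0] for c in torch.chunk(torch.arange(n), chunks)]
assert sizes == [3, 2]
```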
@@ -3625,7 +3616,12 @@ def unflatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
         return self

     @abc.abstractmethod
-    def _index_tensordict(self, index: IndexType) -> T:
+    def _index_tensordict(
+        self,
+        index: IndexType,
+        new_batch_size: torch.Size | None = None,
+        names: List[str] | None = None,
+    ) -> T:
         ...

     # Locking functionality
6 changes: 1 addition & 5 deletions test/test_tensordict.py
@@ -2750,11 +2750,7 @@ def test_split(self, td_name, device, performer, dim):

         for idx, split_td in enumerate(tds):
             expected_split_dim_size = 1 if idx == rep else 2
-            expected_batch_size = [
-                expected_split_dim_size if dim_idx == dim else dim_size
-                for (dim_idx, dim_size) in enumerate(td.batch_size)
-            ]
-
+            expected_batch_size = tensorsplit[idx].shape
             # Test each split_td has the expected batch_size
             assert split_td.batch_size == torch.Size(expected_batch_size)

1 comment on commit 3689afa

@github-actions

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous one, exceeding the alert threshold of 2.

| Benchmark suite | Current: 3689afa | Previous: a5bed01 | Ratio |
| --- | --- | --- | --- |
| benchmarks/common/common_ops_test.py::test_lock_stack_nested | 121.30358415869281 iter/sec (stddev: 0.015246090448835372) | 267.467234829666 iter/sec (stddev: 0.0008085172348222808) | 2.20 |

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens
