[Performance] Better shared/memmap inheritance and faster exclude (#621)
vmoens authored Jan 17, 2024
1 parent c85acfb commit 99eff33
Showing 8 changed files with 779 additions and 328 deletions.
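The core of the change, visible in the tensordict/_lazy.py diff below, is that the lazy tensordict classes stop caching _is_shared/_is_memmap at construction time and instead expose them as properties computed from their component or source tensordicts. The snippet below is an illustrative sketch only, not code from the commit; it assumes LazyStackedTensorDict and its lazy_stack constructor are importable from the top-level tensordict package, as they are used inside the diff.

import torch
from tensordict import TensorDict, LazyStackedTensorDict

a = TensorDict({"x": torch.zeros(3)}, batch_size=[3])
b = TensorDict({"x": torch.zeros(3)}, batch_size=[3])
stacked = LazyStackedTensorDict.lazy_stack([a, b], dim=0)

# Sharing the stack shares every component; the stack no longer stores a
# cached flag but reports the status derived from its components.
stacked.share_memory_()
assert stacked.is_shared()
assert all(td.is_shared() for td in stacked.tensordicts)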
benchmarks/distributed/distributed_benchmark_test.py (11 changes: 4 additions & 7 deletions)
@@ -10,7 +10,7 @@
import pytest
import torch

from tensordict import MemoryMappedTensor, TensorDict
from tensordict import TensorDict
from torch.distributed import rpc

MAIN_NODE = "Main"
@@ -66,15 +66,12 @@ def exec_distributed_test(rank_node):
# create a tensordict that is 1GB big, stored on disk, assuming that both nodes have access to /tmp/
tensordict = TensorDict(
{
"memmap": MemoryMappedTensor.empty(
(1000, 640, 640, 3),
dtype=torch.uint8,
filename=tmpdir / "mmap.memmap",
"memmap": torch.empty((), dtype=torch.uint8).expand(
(1000, 640, 640, 3)
)
},
[1000],
_is_memmap=True,
)
).memmap_(tmpdir, copy_existing=False)
assert tensordict.is_memmap()

while True:
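The benchmark above now builds its memory-mapped tensordict by expanding a scalar tensor and memory-mapping the whole tensordict in one call, instead of pre-allocating a MemoryMappedTensor and passing the private _is_memmap flag. A scaled-down sketch of that pattern follows; the temporary directory and the smaller shape are placeholders chosen so the example stays cheap to run.

import tempfile

import torch
from tensordict import TensorDict

tmpdir = tempfile.mkdtemp()  # placeholder directory, not the benchmark's /tmp path
td = TensorDict(
    {"memmap": torch.empty((), dtype=torch.uint8).expand(100, 64, 64, 3)},
    batch_size=[100],
).memmap_(tmpdir, copy_existing=False)
assert td.is_memmap()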
tensordict/_lazy.py (180 changes: 131 additions & 49 deletions)
@@ -184,8 +184,6 @@ def __init__(
hook_in: callable | None = None,
batch_size: Sequence[int] | None = None, # TODO: remove
) -> None:
self._is_shared = False
self._is_memmap = False
self._is_locked = None

# sanity check
@@ -226,6 +224,15 @@ def __init__(
if batch_size is not None and batch_size != self.batch_size:
raise RuntimeError("batch_size does not match self.batch_size.")

# These attributes should never be set
@property
def _is_shared(self):
return all(td._is_shared for td in self.tensordicts)

@property
def _is_memmap(self):
return all(td._is_memmap for td in self.tensordicts)

@property
@cache # noqa: B019
def _has_exclusive_keys(self):
@@ -1225,22 +1232,22 @@ def empty(self, recurse=False) -> T:
stack_dim=self.stack_dim,
)

def clone(self, recurse: bool = True) -> T:
def _clone(self, recurse: bool = True) -> T:
if recurse:
# This could be optimized using copy but we must be careful with
# metadata (_is_shared etc)
out = LazyStackedTensorDict(
*[td.clone() for td in self.tensordicts],
result = LazyStackedTensorDict(
*[td._clone() for td in self.tensordicts],
stack_dim=self.stack_dim,
)
else:
out = LazyStackedTensorDict(
*[td.clone(recurse=False) for td in self.tensordicts],
result = LazyStackedTensorDict(
*[td._clone(recurse=False) for td in self.tensordicts],
stack_dim=self.stack_dim,
)
if self._td_dim_name is not None:
out._td_dim_name = self._td_dim_name
return out
result._td_dim_name = self._td_dim_name
return result

def pin_memory(self) -> T:
for td in self.tensordicts:
@@ -1382,26 +1389,30 @@ def _apply_nest(
out.names = names
return out

def select(
def _select(
self, *keys: str, inplace: bool = False, strict: bool = False
) -> LazyStackedTensorDict:
# the following implementation keeps the hidden keys in the tensordicts
tensordicts = [
td.select(*keys, inplace=inplace, strict=strict) for td in self.tensordicts
td._select(*keys, inplace=inplace, strict=strict) for td in self.tensordicts
]
if inplace:
return self
return LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim)
result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim)
self._maybe_set_shared_attributes(result)
return result

def exclude(self, *keys: str, inplace: bool = False) -> LazyStackedTensorDict:
def _exclude(self, *keys: str, inplace: bool = False) -> LazyStackedTensorDict:
tensordicts = [
tensordict.exclude(*keys, inplace=inplace)
tensordict._exclude(*keys, inplace=inplace)
for tensordict in self.tensordicts
]
if inplace:
self.tensordicts = tensordicts
return self
return LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim)
result = LazyStackedTensorDict(*tensordicts, stack_dim=self.stack_dim)
self._maybe_set_shared_attributes(result)
return result

def __setitem__(self, index: IndexType, value: T) -> T:
if isinstance(index, (tuple, str)):
@@ -1854,7 +1865,6 @@ def pop(
def share_memory_(self) -> T:
for td in self.tensordicts:
td.share_memory_()
self._is_shared = True
self.lock_()
return self

@@ -1905,8 +1915,6 @@ def save_metadata(prefix=prefix, self=self):
results = LazyStackedTensorDict.lazy_stack(results, dim=self.stack_dim)
else:
results = self
results._is_memmap = True
results._is_shared = False
results._device = torch.device("cpu")
return results

@@ -2222,8 +2230,6 @@ def _propagate_unlock(self):
# stack we won't iterate multiple times over it
sub_tds[id(child)] = child._propagate_unlock() + [child]
sub_tds = [item for value in sub_tds.values() for item in value]
self._is_shared = False
self._is_memmap = False
return sub_tds

def __repr__(self):
@@ -2264,6 +2270,60 @@ def _repr_exclusive_fields(self):

return "\n" + exclusive_key_str

def _view(self, *args, **kwargs):
raise RuntimeError(
"Cannot call `view` on a lazy stacked tensordict. Call `reshape` instead."
)

def _transpose(self, dim0, dim1):
if self._is_vmapped:
raise RuntimeError("cannot call transpose within vmap.")
if dim0 == self.stack_dim:
# we know dim0 and dim1 are sorted so dim1 comes after dim0
# example: shape = [5, 4, 3, 2, 1], stack_dim=1, dim0=1, dim1=4
# resulting shape: [5, 1, 3, 2, 4]
if dim1 == dim0 + 1:
return LazyStackedTensorDict(*self.tensordicts, stack_dim=dim1)
return LazyStackedTensorDict(
*(td.transpose(dim0, dim1 - 1) for td in self.tensordicts),
stack_dim=dim1,
)
elif dim1 == self.stack_dim:
# example: shape = [5, 4, 3, 2, 1], stack_dim=3, dim0=1, dim1=3
# resulting shape: [5, 2, 3, 4, 1]
if dim0 + 1 == dim1:
return LazyStackedTensorDict(*self.tensordicts, stack_dim=dim0)
return LazyStackedTensorDict(
*(td.transpose(dim0 + 1, dim1) for td in self.tensordicts),
stack_dim=dim0,
)
else:
dim0 = dim0 if dim0 < self.stack_dim else dim0 - 1
dim1 = dim1 if dim1 < self.stack_dim else dim1 - 1
return LazyStackedTensorDict(
*(td.transpose(dim0, dim1) for td in self.tensordicts),
stack_dim=self.stack_dim,
)

def _permute(
self,
*args,
**kwargs,
):
raise RuntimeError(
"Cannot call `permute` on a lazy stacked tensordict. Make it dense before calling this method by calling `to_tensordict`."
)

def _squeeze(self, dim=None):
raise RuntimeError(
"Cannot call `squeeze` on a lazy stacked tensordict. Make it dense before calling this method by calling `to_tensordict`."
)

def _unsqueeze(self, dim):
raise RuntimeError(
"Cannot call `unsqueeze` on a lazy stacked tensordict. Make it dense before calling this method by calling `to_tensordict`."
)

lock_ = TensorDictBase.lock_
lock = _renamed_inplace_method(lock_)

@@ -2278,9 +2338,6 @@ def _repr_exclusive_fields(self):
reshape = TensorDict.reshape
split = TensorDict.split
to_module = TensorDict.to_module
_permute = TensorDict._permute
_transpose = TensorDict._transpose
_view = TensorDict._view


class _CustomOpTensorDict(TensorDictBase):
@@ -2298,8 +2355,6 @@ def __init__(
inv_op_kwargs: dict | None = None,
batch_size: Sequence[int] | None = None,
) -> None:
self._is_shared = source.is_shared()
self._is_memmap = source.is_memmap()

if not isinstance(source, TensorDictBase):
raise TypeError(
@@ -2315,6 +2370,15 @@ def __init__(
if batch_size is not None and batch_size != self.batch_size:
raise RuntimeError("batch_size does not match self.batch_size.")

# These attributes should never be set
@property
def _is_shared(self):
return self._source._is_shared

@property
def _is_memmap(self):
return self._source._is_memmap

def is_empty(self) -> bool:
return self._source.is_empty()

@@ -2430,8 +2494,7 @@ def _set_tuple(self, key, value, *, inplace: bool, validated: bool):
return self._set_str(key[0], value, inplace=inplace, validated=validated)
source = self._source._get_str(key[0], None)
if source is None:
self._source._create_nested_str(key[0])
source = self._source._get_str(key[0], NO_DEFAULT)
source = self._source._create_nested_str(key[0])
nested = type(self)(
source,
custom_op=self.custom_op,
@@ -2504,29 +2567,19 @@ def keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
)

def select(
def _select(
self, *keys: str, inplace: bool = False, strict: bool = True
) -> _CustomOpTensorDict:
if inplace:
self._source.select(*keys, inplace=inplace, strict=strict)
return self
self_copy = copy(self)
self_copy._source = self_copy._source.select(*keys, strict=strict)
return self_copy
raise RuntimeError("Cannot call select inplace on a lazy tensordict.")
return self.to_tensordict()._select(*keys, inplace=False, strict=strict)

def exclude(self, *keys: str, inplace: bool = False) -> T:
def _exclude(self, *keys: str, inplace: bool = False) -> _CustomOpTensorDict:
if inplace:
return super().exclude(*keys, inplace=True)
return TensorDict(
{key: value.clone() for key, value in self.items()},
batch_size=self.batch_size,
device=self.device,
_run_checks=False,
_is_memmap=self.is_memmap(),
_is_shared=self.is_shared(),
).exclude(*keys, inplace=True)
raise RuntimeError("Cannot call exclude inplace on a lazy tensordict.")
return self.to_tensordict()._exclude(*keys, inplace=False)

def clone(self, recurse: bool = True) -> T:
def _clone(self, recurse: bool = True) -> T:
"""Clones the Lazy TensorDict.
Args:
Expand All @@ -2552,7 +2605,7 @@ def is_contiguous(self) -> bool:
def contiguous(self) -> T:
if self.is_contiguous():
return self
return self.to(TensorDict)
return self._fast_apply(lambda x: x.contiguous())

def rename_key_(
self, old_key: str, new_key: str, safe: bool = False
@@ -2698,7 +2751,6 @@ def _load_memmap(cls, prefix: str, metadata: dict) -> _CustomOpTensorDict:

def share_memory_(self) -> _CustomOpTensorDict:
self._source.share_memory_()
self._is_shared = True
self.lock_()
return self

@@ -2738,6 +2790,10 @@ def _remove_lock(self, lock_id):
def _propagate_lock(self, lock_ids):
return self._source._propagate_lock(lock_ids)

@erase_cache
def _propagate_unlock(self):
return self._source._propagate_unlock()

lock = _renamed_inplace_method(lock_)
unlock = _renamed_inplace_method(unlock_)

@@ -2748,6 +2804,35 @@ def __del__(self):
def sorted_keys(self):
return self._source.sorted_keys

def _view(self, *args, **kwargs):
raise RuntimeError(
"Cannot call `view` on a lazy tensordict. Call `reshape` instead."
)

def _transpose(self, dim0, dim1):
raise RuntimeError(
"Cannot call `transpose` on a lazy tensordict. Make it dense before calling this method by calling `to_tensordict`."
)

def _permute(
self,
*args,
**kwargs,
):
raise RuntimeError(
"Cannot call `permute` on a lazy tensordict. Make it dense before calling this method by calling `to_tensordict`."
)

def _squeeze(self, dim=None):
raise RuntimeError(
"Cannot call `squeeze` on a lazy tensordict. Make it dense before calling this method by calling `to_tensordict`."
)

def _unsqueeze(self, dim):
raise RuntimeError(
"Cannot call `unsqueeze` on a lazy tensordict. Make it dense before calling this method by calling `to_tensordict`."
)

__xor__ = TensorDict.__xor__
__or__ = TensorDict.__or__
__eq__ = TensorDict.__eq__
@@ -2768,9 +2853,6 @@ def sorted_keys(self):
any = TensorDict.any
expand = TensorDict.expand
unbind = TensorDict.unbind
_permute = TensorDict._permute
_transpose = TensorDict._transpose
_view = TensorDict._view
_get_names_idx = TensorDict._get_names_idx


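The "faster exclude" side of the commit appears in the _select/_exclude rewrites above: a lazy stack excludes keys component by component and propagates its shared/memmap attributes to the result, while the lazy _CustomOpTensorDict views refuse inplace select/exclude and densify via to_tensordict instead. The sketch below is a usage illustration, not code from the commit, and it assumes the public exclude method dispatches to the _exclude implementation shown in the diff.

import torch
from tensordict import TensorDict, LazyStackedTensorDict

a = TensorDict({"obs": torch.zeros(4), "aux": torch.ones(4)}, batch_size=[4])
b = TensorDict({"obs": torch.zeros(4), "aux": torch.ones(4)}, batch_size=[4])
stacked = LazyStackedTensorDict.lazy_stack([a, b], dim=0)

# Excluding a key drops it from every component and returns a new lazy stack;
# the original components are untouched because inplace defaults to False.
pruned = stacked.exclude("aux")
assert "aux" not in pruned.keys()
assert "aux" in stacked.keys()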