amend
vmoens committed Oct 9, 2023
1 parent 15e6950 commit a235463
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions tensordict/tensordict.py
@@ -748,7 +748,7 @@ def state_dict(
"""
out = collections.OrderedDict()
-source = self.apply(memmap_tensor_as_tensor)
+source = self._fast_apply(memmap_tensor_as_tensor)
if flatten:
source = source.flatten_keys(".")
for key, item in source.items():
@@ -1807,7 +1807,7 @@ def as_tensor(self):
"""
try:
-return self.apply(lambda x: x.as_tensor())
+return self._fast_apply(lambda x: x.as_tensor())
except AttributeError as err:
raise AttributeError(
f"{self.__class__.__name__} does not have an 'as_tensor' method "
@@ -2238,7 +2238,7 @@ def flatten(tensor):
)
else:
batch_size = [nelt] + list(self.batch_size[end_dim + 1 :])
-out = self.apply(flatten, batch_size=batch_size)
+out = self._fast_apply(flatten, batch_size=batch_size)
if self._has_names():
names = [
name
@@ -2284,7 +2284,7 @@ def unflatten(tensor):
)
else:
batch_size = list(unflattened_size) + list(self.batch_size[1:])
-out = self.apply(unflatten, batch_size=batch_size)
+out = self._fast_apply(unflatten, batch_size=batch_size)
if self._has_names():
names = copy(self.names)
for _ in range(len(unflattened_size) - 1):
@@ -2588,7 +2588,7 @@ def detach(self) -> T:
a new tensordict with no tensor requiring gradient.
"""
-return self.apply(lambda x: x.detach())
+return self._fast_apply(lambda x: x.detach())

def to_h5(
self,
@@ -3783,7 +3783,7 @@ def fill_(self, key: NestedKey, value: float | bool) -> T:
key = _unravel_key_to_tuple(key)
data = self._get_tuple(key, NO_DEFAULT)
if _is_tensor_collection(data.__class__):
-data.apply_(lambda x: x.fill_(value))
+data._fast_apply(lambda x: x.fill_(value), inplace=True)
# self._set(key, tensordict, inplace=True)
else:
data = data.fill_(value)
@@ -3923,27 +3923,27 @@ def is_floating_point(self):

def double(self):
r"""Casts all tensors to ``torch.bool``."""
return self.apply(lambda x: x.double())
return self._fast_apply(lambda x: x.double())

def float(self):
r"""Casts all tensors to ``torch.float``."""
-return self.apply(lambda x: x.float())
+return self._fast_apply(lambda x: x.float())

def int(self):
r"""Casts all tensors to ``torch.int``."""
-return self.apply(lambda x: x.int())
+return self._fast_apply(lambda x: x.int())

def bool(self):
r"""Casts all tensors to ``torch.bool``."""
-return self.apply(lambda x: x.bool())
+return self._fast_apply(lambda x: x.bool())

def half(self):
r"""Casts all tensors to ``torch.half``."""
-return self.apply(lambda x: x.half())
+return self._fast_apply(lambda x: x.half())

def bfloat16(self):
r"""Casts all tensors to ``torch.bfloat16``."""
-return self.apply(lambda x: x.bfloat16())
+return self._fast_apply(lambda x: x.bfloat16())

def type(self, dst_type):
r"""Casts all tensors to :attr:`dst_type`.
@@ -3952,7 +3952,7 @@ def type(self, dst_type):
dst_type (type or string): the desired type
"""
-return self.apply(lambda x: x.type(dst_type))
+return self._fast_apply(lambda x: x.type(dst_type))


_ACCEPTED_CLASSES = [
@@ -4367,7 +4367,7 @@ def pin_memory(self) -> T:
def pin_mem(tensor):
return tensor.pin_memory()

-return self.apply(pin_mem)
+return self._fast_apply(pin_mem)

@overload
def expand(self, *shape: int) -> T:
@@ -4801,15 +4801,15 @@ def func(tensor, _other):
expand_as_right(condition, tensor), tensor, _other
)

-return self.apply(func, other)
+return self._fast_apply(func, other)
else:

def func(tensor):
return torch.where(
expand_as_right(condition, tensor), tensor, other
)

-return self.apply(func)
+return self._fast_apply(func)
else:
if _is_tensor_collection(other.__class__):

@@ -4818,15 +4818,15 @@ def func(tensor, _other, _out):
expand_as_right(condition, tensor), tensor, _other, out=_out
)

-return self.apply(func, other, out)
+return self._fast_apply(func, other, out)
else:

def func(tensor, _out):
return torch.where(
expand_as_right(condition, tensor), tensor, other, out=_out
)

-return self.apply(func, out)
+return self._fast_apply(func, out)

def masked_fill_(self, mask: Tensor, value: float | int | bool) -> T:
for item in self.values():
@@ -5185,7 +5185,7 @@ def _full_like(td: T, fill_value: float, **kwargs: Any) -> T:

@implements_for_td(torch.zeros_like)
def _zeros_like(td: T, **kwargs: Any) -> T:
-td_clone = td.apply(torch.zeros_like)
+td_clone = td._fast_apply(torch.zeros_like)
if "dtype" in kwargs:
raise ValueError("Cannot pass dtype to zeros_like with TensorDict")
if "device" in kwargs:
@@ -5200,7 +5200,7 @@ def _zeros_like(td: T, **kwargs: Any) -> T:

@implements_for_td(torch.ones_like)
def _ones_like(td: T, **kwargs: Any) -> T:
-td_clone = td.apply(lambda x: torch.ones_like(x))
+td_clone = td._fast_apply(lambda x: torch.ones_like(x))
if "device" in kwargs:
td_clone = td_clone.to(kwargs.pop("device"))
if len(kwargs):
@@ -5221,7 +5221,7 @@ def _empty_like(td: T, *args, **kwargs) -> T:
"cloned, preventing empty_like to be called. "
"Consider calling tensordict.to_tensordict() first."
) from err
-return tdclone.apply_(lambda x: torch.empty_like(x, *args, **kwargs))
+return tdclone._fast_apply(lambda x: torch.empty_like(x, *args, **kwargs), inplace=True)


@implements_for_td(torch.clone)
@@ -6142,7 +6142,7 @@ def expand(self, *args: int, inplace: bool = False) -> T:
shape = tuple(args[0])
else:
shape = args
-return self.apply(
+return self._fast_apply(
lambda x: x.expand((*shape, *x.shape[self.ndim :])), batch_size=shape
)

@@ -7007,7 +7007,7 @@ def _add_batch_dim(self, *, in_dim, vmap_level):
in_dim = in_dim - 1
stack_dim = td.stack_dim
tds = [
-td.apply(
+td._fast_apply(
lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level),
batch_size=[b for i, b in enumerate(td.batch_size) if i != in_dim],
names=[name for i, name in enumerate(td.names) if i != in_dim],
@@ -7240,7 +7240,7 @@ def entry_class(self, key: NestedKey) -> type:
def apply_(self, fn: Callable, *others):
for i, td in enumerate(self.tensordicts):
idx = (slice(None),) * self.stack_dim + (i,)
-td.apply_(fn, *[other[idx] for other in others])
+td._fast_apply(fn, *[other[idx] for other in others], inplace=True)
return self

def _apply(
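Every hunk above applies the same substitution: shape-preserving, elementwise operations (dtype casts, detach, zeros_like/ones_like, expand, where) move from the public apply to the internal _fast_apply, and in-place variants (apply_) become _fast_apply(..., inplace=True). The sketch below is an assumed, standalone illustration of why skipping per-call validation can help; it is not tensordict's implementation, and the helper names apply_with_checks and fast_apply are hypothetical.

# Illustrative sketch only: approximates the assumed contrast between a
# validated apply (slow path) and a validation-free apply (fast path),
# mirroring the substitutions made in this commit. Simplified stand-ins,
# not tensordict code.
import torch


def apply_with_checks(data: dict, fn, batch_size):
    """Apply fn to every tensor and re-validate each result (slow path)."""
    out = {}
    for key, value in data.items():
        result = fn(value)
        # Assumed extra cost: re-checking that every result still matches
        # the expected leading batch dimensions before it is accepted.
        if result.shape[: len(batch_size)] != torch.Size(batch_size):
            raise RuntimeError(f"batch shape mismatch for {key!r}")
        out[key] = result
    return out


def fast_apply(data: dict, fn):
    """Apply fn without re-validating shapes (fast path).

    Safe only when fn is shape-preserving, e.g. dtype casts such as
    x.half() or torch.zeros_like, which is true of every call site
    touched in this commit.
    """
    return {key: fn(value) for key, value in data.items()}


if __name__ == "__main__":
    td = {"obs": torch.randn(4, 3), "reward": torch.randn(4, 1)}
    checked = apply_with_checks(td, lambda x: x.float(), batch_size=(4,))
    fast = fast_apply(td, lambda x: x.half())
    print({k: v.dtype for k, v in fast.items()})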
