Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 21, 2024
2 parents 85d0c49 + b97d25d commit d5c452d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
2 changes: 1 addition & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def _quick_set(swap_dict, swap_td):
_quick_set(_swap, swap_dest)
return swap_dest
else:
return TensorDict._new_unsafe(_swap, batch_size=[])
return TensorDict._new_unsafe(_swap, batch_size=torch.Size(()))

def __ne__(self, other: object) -> T | bool:
if is_tensorclass(other):
Expand Down
23 changes: 15 additions & 8 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def new_func(self, *args, **kwargs):
if out is self._param_td:
return self
if not isinstance(out, TensorDictParams):
out = TensorDictParams(out, no_convert=True)
out = TensorDictParams(out, no_convert="skip")
out.no_convert = self.no_convert
return out

Expand Down Expand Up @@ -328,11 +328,12 @@ def __init__(
parameters = parameters._param_td
self._param_td = parameters
self.no_convert = no_convert
if not no_convert:
func = _maybe_make_param
else:
func = _maybe_make_param_or_buffer
self._param_td = _apply_leaves(self._param_td, lambda x: func(x))
if no_convert != "skip":
if not no_convert:
func = _maybe_make_param
else:
func = _maybe_make_param_or_buffer
self._param_td = _apply_leaves(self._param_td, lambda x: func(x))
self._lock_content = lock
if lock:
self._param_td.lock_()
Expand All @@ -341,6 +342,12 @@ def __init__(
self._locked_tensordicts = []
self._get_post_hook = []

@classmethod
def _new_unsafe(
cls, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
):
return TensorDictParams(parameters, no_convert="skip", lock=lock)

def __iter__(self):
yield from self._param_td.__iter__()

Expand Down Expand Up @@ -613,7 +620,7 @@ def _clone(self, recurse: bool = True) -> TensorDictBase:
"""
if not recurse:
return TensorDictParams(self._param_td._clone(False), no_convert=True)
return TensorDictParams(self._param_td._clone(False), no_convert="skip")

memo = {}

Expand All @@ -631,7 +638,7 @@ def _clone(tensor, memo=memo):
memo[tensor] = result
return result

return TensorDictParams(self._param_td.apply(_clone), no_convert=True)
return TensorDictParams(self._param_td.apply(_clone), no_convert="skip")

@_fallback
def chunk(self, chunks: int, dim: int = 0) -> tuple[TensorDictBase, ...]: ...
Expand Down

0 comments on commit d5c452d

Please sign in to comment.