Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 2, 2024
1 parent 8e271ee commit d15d007
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,16 +797,6 @@ def __setitem__(
subtd = None
for value_key, item in value.items():
if value_key in keys:
if is_non_tensor(item):
dest = self._get_str(value_key, NO_DEFAULT)
if (dest[index] != item).any():
dest_stack = dest.maybe_to_stack()
if dest_stack is not dest:
dest_stack[index] = item
self._set_str(
value_key, dest_stack, validated=True, inplace=False
)
continue
self._set_at_str(
value_key, item, index, validated=False, non_blocking=False
)
Expand Down Expand Up @@ -2093,12 +2083,35 @@ def _set_tuple(
)
return self

_SHARED_INPLACE_ERROR = (
"You're attempting to update a leaf in-place with a shared "
"tensordict, but the new value does not match the previous. "
"If you're using NonTensorData, see the class documentation "
"to see how to properly pre-allocate memory in shared contexts."
)

def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool):
if not validated:
value = self._validate_value(value, check_shape=False)
validated = True
tensor_in = self._get_str(key, NO_DEFAULT)

if is_non_tensor(value) and not (self._is_shared or self._is_memmap):
dest = self._get_str(key, NO_DEFAULT)
is_diff = dest[idx].tolist() != value.tolist()
if is_diff:
dest_val = dest.maybe_to_stack()
dest_val[idx] = value
if dest_val is not dest:
self._set_str(
key,
dest_val,
validated=True,
inplace=False,
ignore_lock=True,
)
return

if isinstance(idx, tuple) and len(idx) and isinstance(idx[0], tuple):
warn(
"Multiple indexing can lead to unexpected behaviours when "
Expand All @@ -2113,12 +2126,7 @@ def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool):
)
if tensor_in is not tensor_out:
if self._is_shared or self._is_memmap:
raise RuntimeError(
"You're attempting to update a leaf in-place with a shared "
"tensordict, but the new value does not match the previous. "
"If you're using NonTensorData, see the class documentation "
"to see how to properly pre-allocate memory in shared contexts."
)
raise RuntimeError(self._SHARED_INPLACE_ERROR)
# this happens only when a NonTensorData becomes a NonTensorStack
# so it is legitimate (there is no in-place modification of a tensor
# that was expected to happen but didn't).
Expand Down

0 comments on commit d15d007

Please sign in to comment.