Skip to content

Commit

Permalink
[Performance] Faster params and buffer registration in TensorDictPara…
Browse files Browse the repository at this point in the history
…ms (#569)
  • Loading branch information
vmoens authored Nov 23, 2023
1 parent dc4eb6b commit 1a7f43a
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,18 +317,21 @@ def _apply_get_post_hook(self, val):
def _reset_params(self):
parameters = self._param_td
param_keys = []
params = []
buffer_keys = []
buffers = []
for key, value in parameters.items(True, True):
# flatten key
if isinstance(key, tuple):
key = "_".join(key)
if isinstance(value, nn.Parameter):
param_keys.append(key)
params.append(value)
else:
buffer_keys.append(key)
self.__dict__["_parameters"] = (
parameters.select(*param_keys).flatten_keys("_").to_dict()
)
self.__dict__["_buffers"] = (
parameters.select(*buffer_keys).flatten_keys("_").to_dict()
)
buffers.append(value)
self.__dict__["_parameters"] = dict(zip(param_keys, params))
self.__dict__["_buffers"] = dict(zip(buffer_keys, buffers))

@classmethod
def __torch_function__(
Expand Down

0 comments on commit 1a7f43a

Please sign in to comment.