diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 9b713a3f9..6cedbf41d 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -317,18 +317,21 @@ def _apply_get_post_hook(self, val): def _reset_params(self): parameters = self._param_td param_keys = [] + params = [] buffer_keys = [] + buffers = [] for key, value in parameters.items(True, True): + # flatten key + if isinstance(key, tuple): + key = "_".join(key) if isinstance(value, nn.Parameter): param_keys.append(key) + params.append(value) else: buffer_keys.append(key) - self.__dict__["_parameters"] = ( - parameters.select(*param_keys).flatten_keys("_").to_dict() - ) - self.__dict__["_buffers"] = ( - parameters.select(*buffer_keys).flatten_keys("_").to_dict() - ) + buffers.append(value) + self.__dict__["_parameters"] = dict(zip(param_keys, params)) + self.__dict__["_buffers"] = dict(zip(buffer_keys, buffers)) @classmethod def __torch_function__(