diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py
index 8255c1b0..fe92cd9a 100644
--- a/sheeprl/algos/dreamer_v3/agent.py
+++ b/sheeprl/algos/dreamer_v3/agent.py
@@ -13,6 +13,7 @@
 from torch import Tensor, device, nn
 from torch.distributions import Distribution, Independent, Normal, TanhTransform, TransformedDistribution
 from torch.distributions.utils import probs_to_logits
+from torch.nn.modules import Module
 
 from sheeprl.algos.dreamer_v2.agent import WorldModel
 from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state
@@ -461,6 +462,93 @@ def imagination(self, prior: Tensor, recurrent_state: Tensor, actions: Tensor) -
         return imagined_prior, recurrent_state
 
 
+class DecoupledRSSM(RSSM):
+    """Decoupled RSSM model for the model-based Dreamer agent, where the representation
+    model is conditioned only on the embedded observations, i.e., it is decoupled
+    from the recurrent state.
+
+    Args:
+        recurrent_model (nn.Module): the recurrent model of the RSSM model described in
+            [https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551).
+        representation_model (nn.Module): the representation model composed of a
+            multi-layer perceptron to compute the stochastic part of the latent state.
+            For more information see [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
+        transition_model (nn.Module): the transition model described in
+            [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
+            The model is composed of a multi-layer perceptron to predict the stochastic part of the latent state.
+        distribution_cfg (Dict[str, Any]): the configs of the distributions.
+        discrete (int, optional): the size of the Categorical variables.
+            Defaults to 32.
+        unimix (float, optional): the percentage of uniform distribution to inject into the categorical
+            distribution over states, i.e. given some logits `l` and probabilities `p = softmax(l)`,
+            then `p = (1 - self.unimix) * p + self.unimix * unif`, where `unif = 1 / self.discrete`.
+            Defaults to 0.01.
+    """
+
+    def __init__(
+        self,
+        recurrent_model: Module,
+        representation_model: Module,
+        transition_model: Module,
+        distribution_cfg: Dict[str, Any],
+        discrete: int = 32,
+        unimix: float = 0.01,
+    ) -> None:
+        super().__init__(recurrent_model, representation_model, transition_model, distribution_cfg, discrete, unimix)
+
+    def dynamic(
+        self, posterior: Tensor, recurrent_state: Tensor, action: Tensor, is_first: Tensor
+    ) -> Tuple[Tensor, Tensor, Tensor]:
+        """
+        Perform one step of the dynamic learning:
+        Recurrent model: compute the recurrent state from the previous latent state and
+        the action taken by the agent, i.e., it computes the deterministic state (or ht).
+        Transition model: predict the prior from the recurrent output.
+        Unlike `RSSM.dynamic`, the posterior is not computed here: the representation model
+        depends only on the embedded observations, so it is applied to the whole sequence
+        beforehand (see `_representation`).
+        For more information see [https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551)
+        and [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
+
+        Args:
+            posterior (Tensor): the stochastic state computed by the representation model (posterior). It is expected
+                to be of dimension `[stoch_size, self.discrete]`, which by default is `[32, 32]`.
+            recurrent_state (Tensor): the recurrent state of the recurrent model.
+            action (Tensor): the action taken by the agent.
+            is_first (Tensor): if this is the first step in the episode.
+
+        Returns:
+            The recurrent state (Tensor): the recurrent state of the recurrent model.
+            The prior stochastic state (Tensor): computed by the transition model.
+            The logits of the prior state (Tensor): computed by the transition model
+                from the recurrent state.
+        """
+        action = (1 - is_first) * action
+        recurrent_state = (1 - is_first) * recurrent_state + is_first * torch.tanh(torch.zeros_like(recurrent_state))
+        posterior = posterior.view(*posterior.shape[:-2], -1)
+        posterior = (1 - is_first) * posterior + is_first * self._transition(recurrent_state, sample_state=False)[
+            1
+        ].view_as(posterior)
+        recurrent_state = self.recurrent_model(torch.cat((posterior, action), -1), recurrent_state)
+        prior_logits, prior = self._transition(recurrent_state)
+        return recurrent_state, prior, prior_logits
+
+    def _representation(self, embedded_obs: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            embedded_obs (Tensor): the embedded real observations provided by the environment.
+
+        Returns:
+            logits (Tensor): the logits of the distribution of the posterior state.
+            posterior (Tensor): the sampled posterior stochastic state.
+        """
+        logits: Tensor = self.representation_model(embedded_obs)
+        logits = self._uniform_mix(logits)
+        return logits, compute_stochastic_state(
+            logits, discrete=self.discrete, validate_args=self.distribution_cfg.validate_args
+        )
+
+
 class PlayerDV3(nn.Module):
     """
     The model of the Dreamer_v3 player.
@@ -486,7 +574,7 @@ class PlayerDV3(nn.Module):
     def __init__(
         self,
         encoder: _FabricModule,
-        rssm: RSSM,
+        rssm: RSSM | DecoupledRSSM,
         actor: _FabricModule,
         actions_dim: Sequence[int],
         num_envs: int,
@@ -495,10 +583,15 @@ def __init__(
         device: device = "cpu",
         discrete_size: int = 32,
         actor_type: str | None = None,
+        decoupled_rssm: bool = False,
     ) -> None:
         super().__init__()
         self.encoder = encoder
-        self.rssm = RSSM(
+        if decoupled_rssm:
+            rssm_cls = DecoupledRSSM
+        else:
+            rssm_cls = RSSM
+        self.rssm = rssm_cls(
             recurrent_model=rssm.recurrent_model.module,
             representation_model=rssm.representation_model.module,
             transition_model=rssm.transition_model.module,
@@ -515,6 +608,7 @@ def __init__(
         self.num_envs = num_envs
         self.validate_args = self.actor.distribution_cfg.validate_args
         self.actor_type = actor_type
+        self.decoupled_rssm = decoupled_rssm
 
     @torch.no_grad()
     def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
@@ -580,7 +674,10 @@ def get_greedy_action(
         self.recurrent_state = self.rssm.recurrent_model(
             torch.cat((self.stochastic_state, self.actions), -1), self.recurrent_state
         )
-        _, self.stochastic_state = self.rssm._representation(self.recurrent_state, embedded_obs)
+        if self.decoupled_rssm:
+            _, self.stochastic_state = self.rssm._representation(embedded_obs)
+        else:
+            _, self.stochastic_state = self.rssm._representation(self.recurrent_state, embedded_obs)
         self.stochastic_state = self.stochastic_state.view(
             *self.stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size
         )
@@ -976,8 +1073,11 @@ def build_agent(
         **world_model_cfg.recurrent_model,
         input_size=int(sum(actions_dim) + stochastic_size),
     )
+    representation_model_input_size = encoder.output_dim
+    if not cfg.algo.decoupled_rssm:
+        representation_model_input_size += recurrent_state_size
     representation_model = MLP(
-        input_dims=recurrent_state_size + encoder.cnn_output_dim + encoder.mlp_output_dim,
+        input_dims=representation_model_input_size,
         output_dim=stochastic_size,
         hidden_sizes=[world_model_cfg.representation_model.hidden_size],
         activation=eval(world_model_cfg.representation_model.dense_act),
@@ -1004,7 +1104,11 @@ def build_agent(
             else None
         ),
     )
-    rssm = RSSM(
+    if cfg.algo.decoupled_rssm:
+        rssm_cls = DecoupledRSSM
+    else:
+        rssm_cls = RSSM
+    rssm = rssm_cls(
         recurrent_model=recurrent_model.apply(init_weights),
         representation_model=representation_model.apply(init_weights),
         transition_model=transition_model.apply(init_weights),
diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py
index 5fbc1124..94e56f75 100644
--- a/sheeprl/algos/dreamer_v3/dreamer_v3.py
+++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -108,23 +108,43 @@ def train(
     # Dynamic Learning
     stoch_state_size = stochastic_size * discrete_size
     recurrent_state = torch.zeros(1, batch_size, recurrent_state_size, device=device)
-    posterior = torch.zeros(1, batch_size, stochastic_size, discrete_size, device=device)
     recurrent_states = torch.empty(sequence_length, batch_size, recurrent_state_size, device=device)
     priors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)
-    posteriors = torch.empty(sequence_length, batch_size, stochastic_size, discrete_size, device=device)
-    posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)
 
     # Embed observations from the environment
     embedded_obs = world_model.encoder(batch_obs)
 
-    for i in range(0, sequence_length):
-        recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic(
-            posterior, recurrent_state, batch_actions[i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1]
-        )
-        recurrent_states[i] = recurrent_state
-        priors_logits[i] = prior_logits
-        posteriors[i] = posterior
-        posteriors_logits[i] = posterior_logits
+    if cfg.algo.decoupled_rssm:
+        posteriors_logits, posteriors = world_model.rssm._representation(embedded_obs)
+        for i in range(0, sequence_length):
+            if i == 0:
+                posterior = torch.zeros_like(posteriors[:1])
+            else:
+                posterior = posteriors[i - 1 : i]
+            recurrent_state, _, prior_logits = world_model.rssm.dynamic(
+                posterior,
+                recurrent_state,
+                batch_actions[i : i + 1],
+                data["is_first"][i : i + 1],
+            )
+            recurrent_states[i] = recurrent_state
+            priors_logits[i] = prior_logits
+    else:
+        posterior = torch.zeros(1, batch_size, stochastic_size, discrete_size, device=device)
+        posteriors = torch.empty(sequence_length, batch_size, stochastic_size, discrete_size, device=device)
+        posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)
+        for i in range(0, sequence_length):
+            recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic(
+                posterior,
+                recurrent_state,
+                batch_actions[i : i + 1],
+                embedded_obs[i : i + 1],
+                data["is_first"][i : i + 1],
+            )
+            recurrent_states[i] = recurrent_state
+            priors_logits[i] = prior_logits
+            posteriors[i] = posterior
+            posteriors_logits[i] = posterior_logits
     latent_states = torch.cat((posteriors.view(*posteriors.shape[:-2], -1), recurrent_states), -1)
 
     # Compute predictions for the observations
@@ -457,6 +477,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         cfg.algo.world_model.recurrent_model.recurrent_state_size,
         fabric.device,
         discrete_size=cfg.algo.world_model.discrete_size,
+        decoupled_rssm=cfg.algo.decoupled_rssm,
     )
 
     # Optimizers
diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py
index 17b32ddf..7fa239fc 100644
--- a/sheeprl/algos/dreamer_v3/evaluate.py
+++ b/sheeprl/algos/dreamer_v3/evaluate.py
@@ -63,6 +63,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
         cfg.algo.world_model.recurrent_model.recurrent_state_size,
         fabric.device,
         discrete_size=cfg.algo.world_model.discrete_size,
+        decoupled_rssm=cfg.algo.decoupled_rssm,
     )
 
     test(player, fabric, cfg, log_dir, sample_actions=True)
diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml
index de4b3079..c033f1d7 100644
--- a/sheeprl/configs/algo/dreamer_v3.yaml
+++ b/sheeprl/configs/algo/dreamer_v3.yaml
@@ -33,6 +33,7 @@
 dense_act: torch.nn.SiLU
 cnn_act: torch.nn.SiLU
 unimix: 0.01
 hafner_initialization: True
+decoupled_rssm: False
 
 # World model
 world_model:
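
Note: the sketch below is a minimal, self-contained illustration of the training-time call pattern this diff introduces. With the decoupled RSSM, the posteriors for a whole sequence come from one batched representation call, and only the recurrent/transition unroll stays sequential. The toy modules (representation, recurrent, transition) and all tensor sizes are hypothetical stand-ins, not sheeprl code, and episode resets are simplified to zeroing, whereas the real `DecoupledRSSM.dynamic` resets the posterior to the transition prior.

import torch
from torch import nn

T, B = 16, 8                                           # sequence length, batch size
stoch, discrete, embed_dim, deter, act_dim = 4, 4, 32, 64, 6
stoch_flat = stoch * discrete

# Toy stand-ins for the representation, recurrent and transition models.
representation = nn.Linear(embed_dim, stoch_flat)      # embedded obs -> posterior logits
recurrent = nn.GRUCell(stoch_flat + act_dim, deter)    # (posterior, action) -> ht
transition = nn.Linear(deter, stoch_flat)              # ht -> prior logits

embedded_obs = torch.randn(T, B, embed_dim)
actions = torch.randn(T, B, act_dim)
is_first = torch.zeros(T, B, 1)
is_first[0] = 1.0

# 1) Decoupled representation: one batched call over the whole sequence,
#    instead of one call per step inside the dynamic-learning loop.
posteriors_logits = representation(embedded_obs)       # (T, B, stoch_flat)
posteriors = (
    torch.distributions.OneHotCategorical(logits=posteriors_logits.view(T, B, stoch, discrete))
    .sample()
    .view(T, B, stoch_flat)
)

# 2) Sequential unroll: feed the posterior of step i-1 (zeros at i == 0)
#    together with the action into the recurrent model, then predict the prior.
ht = torch.zeros(B, deter)
priors_logits = torch.empty(T, B, stoch_flat)
for i in range(T):
    posterior = torch.zeros(B, stoch_flat) if i == 0 else posteriors[i - 1]
    # Simplified episode-boundary reset (the real `dynamic` resets the
    # posterior to the transition prior of the reset recurrent state).
    posterior = (1 - is_first[i]) * posterior
    action = (1 - is_first[i]) * actions[i]
    ht = (1 - is_first[i]) * ht
    ht = recurrent(torch.cat((posterior, action), -1), ht)
    priors_logits[i] = transition(ht)

# The KL-balancing loss then compares posteriors_logits against priors_logits,
# exactly as in the coupled RSSM case.

With this patch applied, the behavior should be reachable through a Hydra override such as `python sheeprl.py exp=dreamer_v3 algo.decoupled_rssm=True` (assuming the standard sheeprl entry point).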