Feature/dv3 parallel stochastic (#225)
* Decoupled RSSM for DV3 agent

* Initialize posterior with prior if is_first is True

* Fix PlayerDV3 creation in evaluation

* Fix representation_model

* Fix compute first prior state with a zero posterior
belerico authored Mar 4, 2024
1 parent e8a68f3 commit c79b3eb
Showing 4 changed files with 143 additions and 16 deletions.
114 changes: 109 additions & 5 deletions sheeprl/algos/dreamer_v3/agent.py
@@ -13,6 +13,7 @@
from torch import Tensor, device, nn
from torch.distributions import Distribution, Independent, Normal, TanhTransform, TransformedDistribution
from torch.distributions.utils import probs_to_logits
from torch.nn.modules import Module

from sheeprl.algos.dreamer_v2.agent import WorldModel
from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state
@@ -461,6 +462,93 @@ def imagination(self, prior: Tensor, recurrent_state: Tensor, actions: Tensor) -
return imagined_prior, recurrent_state


class DecoupledRSSM(RSSM):
"""RSSM model for the model-base Dreamer agent.
Args:
recurrent_model (nn.Module): the recurrent model of the RSSM model described in
[https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551).
representation_model (nn.Module): the representation model composed by a
multi-layer perceptron to compute the stochastic part of the latent state.
For more information see [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
transition_model (nn.Module): the transition model described in
[https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
The model is composed by a multi-layer perceptron to predict the stochastic part of the latent state.
distribution_cfg (Dict[str, Any]): the configs of the distributions.
discrete (int, optional): the size of the Categorical variables.
Defaults to 32.
unimix: (float, optional): the percentage of uniform distribution to inject into the categorical
distribution over states, i.e. given some logits `l` and probabilities `p = softmax(l)`,
then `p = (1 - self.unimix) * p + self.unimix * unif`, where `unif = `1 / self.discrete`.
Defaults to 0.01.
"""

def __init__(
self,
recurrent_model: Module,
representation_model: Module,
transition_model: Module,
distribution_cfg: Dict[str, Any],
discrete: int = 32,
unimix: float = 0.01,
) -> None:
super().__init__(recurrent_model, representation_model, transition_model, distribution_cfg, discrete, unimix)

def dynamic(
self, posterior: Tensor, recurrent_state: Tensor, action: Tensor, is_first: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Perform one step of the dynamic learning:
Recurrent model: compute the recurrent state from the previous stochastic state and the action
taken by the agent, i.e., it computes the deterministic state (or ht).
Transition model: predict the prior from the recurrent output.
Unlike `RSSM.dynamic`, the representation model is not called here: the posterior is computed
separately from the embedded observations (see `_representation`) and passed in as an argument.
For more information see [https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551)
and [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
Args:
posterior (Tensor): the stochastic state computed by the representation model (posterior). It is expected
to be of dimension `[stoch_size, self.discrete]`, which by default is `[32, 32]`.
recurrent_state (Tensor): the recurrent state of the recurrent model.
action (Tensor): the action taken by the agent.
is_first (Tensor): whether this is the first step of an episode.
Returns:
The recurrent state (Tensor): the recurrent state of the recurrent model.
The prior stochastic state (Tensor): computed by the transition model.
The logits of the prior state (Tensor): computed by the transition model from the recurrent state.
"""
action = (1 - is_first) * action
recurrent_state = (1 - is_first) * recurrent_state + is_first * torch.tanh(torch.zeros_like(recurrent_state))
posterior = posterior.view(*posterior.shape[:-2], -1)
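# Where is_first, replace the (flattened) previous posterior with the prior predicted
# from the reset recurrent state, i.e. the first prior is computed from a zero posterior.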
posterior = (1 - is_first) * posterior + is_first * self._transition(recurrent_state, sample_state=False)[
1
].view_as(posterior)
recurrent_state = self.recurrent_model(torch.cat((posterior, action), -1), recurrent_state)
prior_logits, prior = self._transition(recurrent_state)
return recurrent_state, prior, prior_logits

def _representation(self, embedded_obs: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
embedded_obs (Tensor): the embedded real observations provided by the environment.
Returns:
logits (Tensor): the logits of the distribution of the posterior state.
posterior (Tensor): the sampled posterior stochastic state.
"""
logits: Tensor = self.representation_model(embedded_obs)
logits = self._uniform_mix(logits)
return logits, compute_stochastic_state(
logits, discrete=self.discrete, validate_args=self.distribution_cfg.validate_args
)


class PlayerDV3(nn.Module):
"""
The model of the Dreamer_v3 player.
@@ -486,7 +574,7 @@ class PlayerDV3(nn.Module):
def __init__(
self,
encoder: _FabricModule,
rssm: RSSM,
rssm: RSSM | DecoupledRSSM,
actor: _FabricModule,
actions_dim: Sequence[int],
num_envs: int,
@@ -495,10 +583,15 @@ def __init__(
device: device = "cpu",
discrete_size: int = 32,
actor_type: str | None = None,
decoupled_rssm: bool = False,
) -> None:
super().__init__()
self.encoder = encoder
self.rssm = RSSM(
if decoupled_rssm:
rssm_cls = DecoupledRSSM
else:
rssm_cls = RSSM
self.rssm = rssm_cls(
recurrent_model=rssm.recurrent_model.module,
representation_model=rssm.representation_model.module,
transition_model=rssm.transition_model.module,
@@ -515,6 +608,7 @@ def __init__(
self.num_envs = num_envs
self.validate_args = self.actor.distribution_cfg.validate_args
self.actor_type = actor_type
self.decoupled_rssm = decoupled_rssm

@torch.no_grad()
def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
@@ -580,7 +674,10 @@ def get_greedy_action(
self.recurrent_state = self.rssm.recurrent_model(
torch.cat((self.stochastic_state, self.actions), -1), self.recurrent_state
)
_, self.stochastic_state = self.rssm._representation(self.recurrent_state, embedded_obs)
if self.decoupled_rssm:
_, self.stochastic_state = self.rssm._representation(embedded_obs)
else:
_, self.stochastic_state = self.rssm._representation(self.recurrent_state, embedded_obs)
self.stochastic_state = self.stochastic_state.view(
*self.stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size
)
@@ -976,8 +1073,11 @@ def build_agent(
**world_model_cfg.recurrent_model,
input_size=int(sum(actions_dim) + stochastic_size),
)
representation_model_input_size = encoder.output_dim
if not cfg.algo.decoupled_rssm:
representation_model_input_size += recurrent_state_size
representation_model = MLP(
input_dims=recurrent_state_size + encoder.cnn_output_dim + encoder.mlp_output_dim,
input_dims=representation_model_input_size,
output_dim=stochastic_size,
hidden_sizes=[world_model_cfg.representation_model.hidden_size],
activation=eval(world_model_cfg.representation_model.dense_act),
@@ -1004,7 +1104,11 @@
else None
),
)
rssm = RSSM(
if cfg.algo.decoupled_rssm:
rssm_cls = DecoupledRSSM
else:
rssm_cls = RSSM
rssm = rssm_cls(
recurrent_model=recurrent_model.apply(init_weights),
representation_model=representation_model.apply(init_weights),
transition_model=transition_model.apply(init_weights),
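Two of the additions in agent.py lend themselves to short standalone sketches. First, the uniform mixing described in the DecoupledRSSM docstring: a minimal sketch assuming logits laid out as `[..., stoch * discrete]` (`uniform_mix` is an illustrative name; the inherited `_uniform_mix` may differ in detail):

import torch
from torch import Tensor
from torch.distributions.utils import probs_to_logits

def uniform_mix(logits: Tensor, unimix: float = 0.01, discrete: int = 32) -> Tensor:
    # One categorical per latent: reshape to [..., stoch, discrete].
    probs = logits.view(*logits.shape[:-1], -1, discrete).softmax(dim=-1)
    # p = (1 - unimix) * p + unimix * unif, with unif = 1 / discrete.
    probs = (1 - unimix) * probs + unimix / discrete
    # Back to logits in the original layout.
    return probs_to_logits(probs).view_as(logits)

Second, the build_agent change: with `decoupled_rssm=True` the representation MLP no longer receives the recurrent state, which is what allows the posteriors of a whole sequence to be computed in one shot. A toy comparison with single Linear layers standing in for the MLPs (sizes and names are illustrative, not the library's API):

import torch
from torch import nn

recurrent_size, embed_size, stoch_logits = 512, 1024, 32 * 32
coupled_repr = nn.Linear(recurrent_size + embed_size, stoch_logits)  # vanilla DV3
decoupled_repr = nn.Linear(embed_size, stoch_logits)                 # this PR

e = torch.randn(10, 16, embed_size)  # [seq, batch, embed]
# The coupled posterior needs the recurrent state h at every step, and h is only
# produced step by step by the recurrent model, forcing a sequential loop.
# The decoupled posterior depends on the observations alone:
decoupled_out = decoupled_repr(e)  # one batched call for the whole sequence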
43 changes: 32 additions & 11 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -108,23 +108,43 @@ def train(
# Dynamic Learning
stoch_state_size = stochastic_size * discrete_size
recurrent_state = torch.zeros(1, batch_size, recurrent_state_size, device=device)
posterior = torch.zeros(1, batch_size, stochastic_size, discrete_size, device=device)
recurrent_states = torch.empty(sequence_length, batch_size, recurrent_state_size, device=device)
priors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)
posteriors = torch.empty(sequence_length, batch_size, stochastic_size, discrete_size, device=device)
posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)

# Embed observations from the environment
embedded_obs = world_model.encoder(batch_obs)

for i in range(0, sequence_length):
recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic(
posterior, recurrent_state, batch_actions[i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1]
)
recurrent_states[i] = recurrent_state
priors_logits[i] = prior_logits
posteriors[i] = posterior
posteriors_logits[i] = posterior_logits
if cfg.algo.decoupled_rssm:
posteriors_logits, posteriors = world_model.rssm._representation(embedded_obs)
for i in range(0, sequence_length):
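# At the first step there is no previous posterior: pass zeros, and dynamic()
# substitutes the prior wherever is_first is set.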
if i == 0:
posterior = torch.zeros_like(posteriors[:1])
else:
posterior = posteriors[i - 1 : i]
recurrent_state, _, prior_logits = world_model.rssm.dynamic(
posterior,
recurrent_state,
batch_actions[i : i + 1],
data["is_first"][i : i + 1],
)
recurrent_states[i] = recurrent_state
priors_logits[i] = prior_logits
else:
posterior = torch.zeros(1, batch_size, stochastic_size, discrete_size, device=device)
posteriors = torch.empty(sequence_length, batch_size, stochastic_size, discrete_size, device=device)
posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)
for i in range(0, sequence_length):
recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic(
posterior,
recurrent_state,
batch_actions[i : i + 1],
embedded_obs[i : i + 1],
data["is_first"][i : i + 1],
)
recurrent_states[i] = recurrent_state
priors_logits[i] = prior_logits
posteriors[i] = posterior
posteriors_logits[i] = posterior_logits
latent_states = torch.cat((posteriors.view(*posteriors.shape[:-2], -1), recurrent_states), -1)

# Compute predictions for the observations
@@ -457,6 +477,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
cfg.algo.world_model.recurrent_model.recurrent_state_size,
fabric.device,
discrete_size=cfg.algo.world_model.discrete_size,
decoupled_rssm=cfg.algo.decoupled_rssm,
)

# Optimizers
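The decoupled branch above replaces the per-step posterior computation with a single batched call to `_representation` over the whole sequence; only the recurrent rollout remains sequential. A minimal sketch of that parallelism, with a Linear stand-in for the representation MLP (shapes and names are illustrative):

import torch
from torch import nn

seq_len, batch, embed, stoch_logits = 64, 16, 1024, 32 * 32
repr_model = nn.Linear(embed, stoch_logits)  # stand-in for the representation MLP
embedded_obs = torch.randn(seq_len, batch, embed)

# Decoupled: one batched forward pass over the whole sequence...
logits_parallel = repr_model(embedded_obs)

# ...equivalent to seq_len sequential calls, which the coupled RSSM cannot
# avoid because its posterior also needs the step-by-step recurrent state.
logits_sequential = torch.stack([repr_model(embedded_obs[i]) for i in range(seq_len)])
assert torch.allclose(logits_parallel, logits_sequential, atol=1e-6)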
1 change: 1 addition & 0 deletions sheeprl/algos/dreamer_v3/evaluate.py
@@ -63,6 +63,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
cfg.algo.world_model.recurrent_model.recurrent_state_size,
fabric.device,
discrete_size=cfg.algo.world_model.discrete_size,
decoupled_rssm=cfg.algo.decoupled_rssm,
)

test(player, fabric, cfg, log_dir, sample_actions=True)
1 change: 1 addition & 0 deletions sheeprl/configs/algo/dreamer_v3.yaml
@@ -33,6 +33,7 @@ dense_act: torch.nn.SiLU
cnn_act: torch.nn.SiLU
unimix: 0.01
hafner_initialization: True
decoupled_rssm: False

# World model
world_model:
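Since `decoupled_rssm` defaults to False, existing Dreamer-V3 configs behave exactly as before. To try the decoupled RSSM, flip the flag in the algo config or, assuming the repo's usual Hydra-style overrides, pass something like `algo.decoupled_rssm=True` on the command line.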
