Feature/dv3 parallel stochastic #225

Merged: 8 commits, Mar 4, 2024
114 changes: 109 additions & 5 deletions sheeprl/algos/dreamer_v3/agent.py
@@ -13,6 +13,7 @@
from torch import Tensor, device, nn
from torch.distributions import Distribution, Independent, Normal, TanhTransform, TransformedDistribution
from torch.distributions.utils import probs_to_logits
from torch.nn.modules import Module

from sheeprl.algos.dreamer_v2.agent import WorldModel
from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state
@@ -461,6 +462,93 @@ def imagination(self, prior: Tensor, recurrent_state: Tensor, actions: Tensor) -
return imagined_prior, recurrent_state


class DecoupledRSSM(RSSM):
"""RSSM model for the model-base Dreamer agent.

Args:
recurrent_model (nn.Module): the recurrent model of the RSSM model described in
[https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551).
representation_model (nn.Module): the representation model composed of a
multi-layer perceptron to compute the stochastic part of the latent state.
For more information see [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
transition_model (nn.Module): the transition model described in
[https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
The model is composed of a multi-layer perceptron to predict the stochastic part of the latent state.
distribution_cfg (Dict[str, Any]): the configs of the distributions.
discrete (int, optional): the size of the Categorical variables.
Defaults to 32.
unimix (float, optional): the percentage of uniform distribution to inject into the categorical
distribution over states, i.e. given some logits `l` and probabilities `p = softmax(l)`,
then `p = (1 - self.unimix) * p + self.unimix * unif`, where `unif = 1 / self.discrete`.
Defaults to 0.01.
"""

def __init__(
self,
recurrent_model: Module,
representation_model: Module,
transition_model: Module,
distribution_cfg: Dict[str, Any],
discrete: int = 32,
unimix: float = 0.01,
) -> None:
super().__init__(recurrent_model, representation_model, transition_model, distribution_cfg, discrete, unimix)

def dynamic(
self, posterior: Tensor, recurrent_state: Tensor, action: Tensor, is_first: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Perform one step of the dynamic learning:
Recurrent model: compute the recurrent state from the previous stochastic state and the action taken
by the agent, i.e., it computes the deterministic state (or ht).
Transition model: predict the prior from the recurrent output.
Unlike the coupled RSSM, the posterior is not computed here: the representation model is decoupled
from the recurrent state and is applied separately to the embedded observations (see `_representation`).
For more information see [https://arxiv.org/abs/1811.04551](https://arxiv.org/abs/1811.04551)
and [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).

Args:
posterior (Tensor): the posterior stochastic state from the previous time step, computed by the
representation model. It is expected to be of dimension `[stoch_size, self.discrete]`, which by default is `[32, 32]`.
recurrent_state (Tensor): the recurrent state of the recurrent model.
action (Tensor): the action taken by the agent.
is_first (Tensor): whether this is the first step of the episode.

Returns:
The recurrent state (Tensor): the recurrent state computed by the recurrent model.
The prior stochastic state (Tensor): sampled by the transition model from the recurrent state.
The logits of the prior state (Tensor): computed by the transition model from the recurrent state.
"""
action = (1 - is_first) * action
recurrent_state = (1 - is_first) * recurrent_state + is_first * torch.tanh(torch.zeros_like(recurrent_state))
posterior = posterior.view(*posterior.shape[:-2], -1)
posterior = (1 - is_first) * posterior + is_first * self._transition(recurrent_state, sample_state=False)[
1
].view_as(posterior)
recurrent_state = self.recurrent_model(torch.cat((posterior, action), -1), recurrent_state)
prior_logits, prior = self._transition(recurrent_state)
return recurrent_state, prior, prior_logits

def _representation(self, embedded_obs: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
embedded_obs (Tensor): the embedded real observations provided by the environment.

Returns:
logits (Tensor): the logits of the distribution of the posterior state.
posterior (Tensor): the sampled posterior stochastic state.
"""
logits: Tensor = self.representation_model(embedded_obs)
logits = self._uniform_mix(logits)
return logits, compute_stochastic_state(
logits, discrete=self.discrete, validate_args=self.distribution_cfg.validate_args
)
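
For reference, the uniform mixing described in the class docstring (inherited from RSSM as `_uniform_mix`) can be sketched as follows; this is a hedged approximation with assumed tensor shapes, not the repository's exact implementation:

import torch
from torch.distributions.utils import probs_to_logits

def uniform_mix(logits: torch.Tensor, unimix: float = 0.01, discrete: int = 32) -> torch.Tensor:
    # Reshape to [..., stoch_size, discrete] categorical variables (shape assumed for this sketch).
    logits = logits.view(*logits.shape[:-1], -1, discrete)
    probs = logits.softmax(dim=-1)
    # Blend with a uniform distribution so every class keeps at least unimix / discrete probability mass.
    probs = (1 - unimix) * probs + unimix / discrete
    # Convert back to logits and flatten the categorical dimensions again.
    return probs_to_logits(probs).view(*probs.shape[:-2], -1)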


class PlayerDV3(nn.Module):
"""
The model of the Dreamer_v3 player.
@@ -486,7 +574,7 @@ class PlayerDV3(nn.Module):
def __init__(
self,
encoder: _FabricModule,
rssm: RSSM,
rssm: RSSM | DecoupledRSSM,
actor: _FabricModule,
actions_dim: Sequence[int],
num_envs: int,
@@ -495,10 +583,15 @@ def __init__(
device: device = "cpu",
discrete_size: int = 32,
actor_type: str | None = None,
decoupled_rssm: bool = False,
) -> None:
super().__init__()
self.encoder = encoder
self.rssm = RSSM(
if decoupled_rssm:
rssm_cls = DecoupledRSSM
else:
rssm_cls = RSSM
self.rssm = rssm_cls(
recurrent_model=rssm.recurrent_model.module,
representation_model=rssm.representation_model.module,
transition_model=rssm.transition_model.module,
@@ -515,6 +608,7 @@ def __init__(
self.num_envs = num_envs
self.validate_args = self.actor.distribution_cfg.validate_args
self.actor_type = actor_type
self.decoupled_rssm = decoupled_rssm

@torch.no_grad()
def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
@@ -580,7 +674,10 @@ def get_greedy_action(
self.recurrent_state = self.rssm.recurrent_model(
torch.cat((self.stochastic_state, self.actions), -1), self.recurrent_state
)
_, self.stochastic_state = self.rssm._representation(self.recurrent_state, embedded_obs)
if self.decoupled_rssm:
_, self.stochastic_state = self.rssm._representation(embedded_obs)
else:
_, self.stochastic_state = self.rssm._representation(self.recurrent_state, embedded_obs)
self.stochastic_state = self.stochastic_state.view(
*self.stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size
)
@@ -976,8 +1073,11 @@ def build_agent(
**world_model_cfg.recurrent_model,
input_size=int(sum(actions_dim) + stochastic_size),
)
represention_model_input_size = encoder.output_dim
if not cfg.algo.decoupled_rssm:
represention_model_input_size += recurrent_state_size
representation_model = MLP(
input_dims=recurrent_state_size + encoder.cnn_output_dim + encoder.mlp_output_dim,
input_dims=represention_model_input_size,
output_dim=stochastic_size,
hidden_sizes=[world_model_cfg.representation_model.hidden_size],
activation=eval(world_model_cfg.representation_model.dense_act),
@@ -1004,7 +1104,11 @@ def __init__(
else None
),
)
rssm = RSSM(
if cfg.algo.decoupled_rssm:
rssm_cls = DecoupledRSSM
else:
rssm_cls = RSSM
rssm = rssm_cls(
recurrent_model=recurrent_model.apply(init_weights),
representation_model=representation_model.apply(init_weights),
transition_model=transition_model.apply(init_weights),
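
The input-size bookkeeping in build_agent follows directly from the decoupling: the standard representation model consumes the concatenation of the recurrent state and the observation embedding, while the decoupled one consumes the embedding alone. A toy check of that arithmetic (the dimensions below are made up for illustration and are not the repository's defaults):

# Assumed toy dimensions, not sheeprl's defaults.
recurrent_state_size, cnn_output_dim, mlp_output_dim = 1024, 4096, 512
encoder_output_dim = cnn_output_dim + mlp_output_dim

decoupled_rssm = False
representation_model_input_size = encoder_output_dim
if not decoupled_rssm:
    # Coupled RSSM: the posterior is also conditioned on the deterministic (recurrent) state.
    representation_model_input_size += recurrent_state_size
print(representation_model_input_size)  # 5632 when coupled, 4608 when decoupled
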
43 changes: 32 additions & 11 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -108,23 +108,43 @@ def train(
# Dynamic Learning
stoch_state_size = stochastic_size * discrete_size
recurrent_state = torch.zeros(1, batch_size, recurrent_state_size, device=device)
posterior = torch.zeros(1, batch_size, stochastic_size, discrete_size, device=device)
recurrent_states = torch.empty(sequence_length, batch_size, recurrent_state_size, device=device)
priors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)
posteriors = torch.empty(sequence_length, batch_size, stochastic_size, discrete_size, device=device)
posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)

# Embed observations from the environment
embedded_obs = world_model.encoder(batch_obs)

for i in range(0, sequence_length):
recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic(
posterior, recurrent_state, batch_actions[i : i + 1], embedded_obs[i : i + 1], data["is_first"][i : i + 1]
)
recurrent_states[i] = recurrent_state
priors_logits[i] = prior_logits
posteriors[i] = posterior
posteriors_logits[i] = posterior_logits
if cfg.algo.decoupled_rssm:
posteriors_logits, posteriors = world_model.rssm._representation(embedded_obs)
for i in range(0, sequence_length):
if i == 0:
posterior = torch.zeros_like(posteriors[:1])
else:
posterior = posteriors[i - 1 : i]
recurrent_state, posterior_logits, prior_logits = world_model.rssm.dynamic(
posterior,
recurrent_state,
batch_actions[i : i + 1],
data["is_first"][i : i + 1],
)
recurrent_states[i] = recurrent_state
priors_logits[i] = prior_logits
else:
posterior = torch.zeros(1, batch_size, stochastic_size, discrete_size, device=device)
posteriors = torch.empty(sequence_length, batch_size, stochastic_size, discrete_size, device=device)
posteriors_logits = torch.empty(sequence_length, batch_size, stoch_state_size, device=device)
for i in range(0, sequence_length):
recurrent_state, posterior, _, posterior_logits, prior_logits = world_model.rssm.dynamic(
posterior,
recurrent_state,
batch_actions[i : i + 1],
embedded_obs[i : i + 1],
data["is_first"][i : i + 1],
)
recurrent_states[i] = recurrent_state
priors_logits[i] = prior_logits
posteriors[i] = posterior
posteriors_logits[i] = posterior_logits
latent_states = torch.cat((posteriors.view(*posteriors.shape[:-2], -1), recurrent_states), -1)
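
The decoupled branch above is where the parallelism of this PR comes from: because the representation model no longer needs the recurrent state, the posteriors and their logits for every timestep are produced by a single batched call before the time loop, and the loop only has to propagate the deterministic state and the priors. A self-contained sketch of that contrast, with toy shapes that are assumptions rather than the repository's defaults:

import torch
import torch.nn as nn

T, B, embed_dim, logits_dim = 64, 16, 512, 32 * 32
representation_model = nn.Linear(embed_dim, logits_dim)  # stand-in for the MLP built in build_agent
embedded_obs = torch.randn(T, B, embed_dim)

# Decoupled RSSM: one forward pass over the whole [T, B, ...] sequence of embeddings.
batched_logits = representation_model(embedded_obs)

# Coupled RSSM (for comparison): the same work done one timestep at a time, because each
# call would also need recurrent_state[i], which is only available after step i - 1.
stepwise_logits = torch.stack([representation_model(embedded_obs[i]) for i in range(T)])
assert torch.allclose(batched_logits, stepwise_logits)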

# Compute predictions for the observations
@@ -457,6 +477,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
cfg.algo.world_model.recurrent_model.recurrent_state_size,
fabric.device,
discrete_size=cfg.algo.world_model.discrete_size,
decoupled_rssm=cfg.algo.decoupled_rssm,
)

# Optimizers
1 change: 1 addition & 0 deletions sheeprl/algos/dreamer_v3/evaluate.py
@@ -63,6 +63,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
cfg.algo.world_model.recurrent_model.recurrent_state_size,
fabric.device,
discrete_size=cfg.algo.world_model.discrete_size,
decoupled_rssm=cfg.algo.decoupled_rssm,
)

test(player, fabric, cfg, log_dir, sample_actions=True)
1 change: 1 addition & 0 deletions sheeprl/configs/algo/dreamer_v3.yaml
@@ -33,6 +33,7 @@ dense_act: torch.nn.SiLU
cnn_act: torch.nn.SiLU
unimix: 0.01
hafner_initialization: True
decoupled_rssm: False

# World model
world_model:
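
The new flag defaults to False, so the decoupled RSSM is opt-in. Assuming the usual Hydra-style configuration composition used by sheeprl (the entrypoint and experiment name below are illustrative, not taken from this diff), it can be enabled either in an experiment config:

algo:
  decoupled_rssm: True

or with a command-line override such as `python sheeprl.py exp=dreamer_v3 algo.decoupled_rssm=True`.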