Set Distribution.validate_args once at the beginning (#248)
* Set Distribution.validate_args once at the beginning

* Move validate_args in run method
belerico authored Mar 30, 2024
1 parent b8afd8a commit 5bc7d78
Showing 20 changed files with 136 additions and 417 deletions.
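In practice, the change summarized in the commit message replaces every per-constructor `validate_args=...` argument with a single global switch set once when the experiment starts. Below is a minimal sketch of what that setup might look like; the exact entry point and config path in sheeprl may differ, and a Hydra-style config exposing `cfg.distribution.validate_args` is assumed:

```python
from torch.distributions import Distribution


def run(cfg) -> None:
    # Turn argument validation on or off once, for every torch distribution
    # instantiated afterwards, instead of threading `validate_args` through
    # each Normal/Independent/OneHotCategorical constructor.
    Distribution.set_default_validate_args(cfg.distribution.validate_args)
    ...
```

With the flag set globally, the call sites in the diff below can simply build `Normal(mean, std)` or `OneHotCategorical(logits=logits)` without repeating the argument.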
13 changes: 3 additions & 10 deletions sheeprl/algos/a2c/agent.py
@@ -9,11 +9,10 @@
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor
from torch.distributions import Distribution, Independent, Normal
from torch.distributions import Distribution, Independent, Normal, OneHotCategorical

from sheeprl.algos.ppo.agent import PPOActor
from sheeprl.models.models import MLP
from sheeprl.utils.distribution import OneHotCategoricalValidateArgs
from sheeprl.utils.fabric import get_single_device_fabric


@@ -127,11 +126,7 @@ def forward(
if self.is_continuous:
mean, log_std = torch.chunk(pre_dist[0], chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(
Normal(mean, std, validate_args=self.distribution_cfg.validate_args),
1,
validate_args=self.distribution_cfg.validate_args,
)
normal = Independent(Normal(mean, std), 1)
if actions is None:
actions = normal.mode if greedy else normal.sample()
else:
@@ -148,9 +143,7 @@
should_append = True
actions: List[Tensor] = []
for i, logits in enumerate(pre_dist):
actions_dist.append(
OneHotCategoricalValidateArgs(logits=logits, validate_args=self.distribution_cfg.validate_args)
)
actions_dist.append(OneHotCategorical(logits=logits))
if should_append:
actions.append(actions_dist[-1].mode if greedy else actions_dist[-1].sample())
actions_logprobs.append(actions_dist[-1].log_prob(actions[i]))
7 changes: 1 addition & 6 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -150,7 +150,6 @@ def _representation(self, recurrent_state: Tensor, embedded_obs: Tensor) -> Tupl
self.representation_model(torch.cat((recurrent_state, embedded_obs), -1)),
event_shape=1,
min_std=self.min_std,
validate_args=self.distribution_cfg.validate_args,
)
return posterior_mean_std, posterior

@@ -165,9 +164,7 @@ def _transition(self, recurrent_out: Tensor) -> Tuple[Tuple[Tensor, Tensor], Ten
The prior state (Tensor): the sampled prior state predicted by the transition model.
"""
prior_mean_std = self.transition_model(recurrent_out)
return compute_stochastic_state(
prior_mean_std, event_shape=1, min_std=self.min_std, validate_args=self.distribution_cfg.validate_args
)
return compute_stochastic_state(prior_mean_std, event_shape=1, min_std=self.min_std)

def imagination(self, stochastic_state: Tensor, recurrent_state: Tensor, actions: Tensor) -> Tuple[Tensor, Tensor]:
"""One-step imagination of the next latent state.
@@ -267,7 +264,6 @@ def __init__(
self.stochastic_size = stochastic_size
self.recurrent_state_size = recurrent_state_size
self.num_envs = num_envs
self.validate_args = self.actor.distribution_cfg.validate_args
self.actor_type = actor_type
self.init_states()

@@ -333,7 +329,6 @@ def get_actions(
)
_, self.stochastic_state = compute_stochastic_state(
self.representation_model(torch.cat((self.recurrent_state, embedded_obs), -1)),
validate_args=self.validate_args,
)
actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask)
self.actions = torch.cat(actions, -1)
40 changes: 6 additions & 34 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -101,7 +101,6 @@ def train(
"""
batch_size = cfg.algo.per_rank_batch_size
sequence_length = cfg.algo.per_rank_sequence_length
validate_args = cfg.distribution.validate_args
recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size
stochastic_size = cfg.algo.world_model.stochastic_size
device = fabric.device
@@ -166,32 +165,17 @@
# an Independent distribution is needed to create (batch_size * sequence_length)
# independent distributions, each producing a sample of size observations.shape
qo = {
k: Independent(
Normal(rec_obs, 1, validate_args=validate_args),
len(rec_obs.shape[2:]),
validate_args=validate_args,
)
for k, rec_obs in decoded_information.items()
}
qo = {k: Independent(Normal(rec_obs, 1), len(rec_obs.shape[2:])) for k, rec_obs in decoded_information.items()}

# compute predictions for the rewards
# an Independent distribution is needed to create (batch_size * sequence_length)
# independent distributions, each producing a sample of size equal to the number of rewards
qr = Independent(
Normal(world_model.reward_model(latent_states), 1, validate_args=validate_args),
1,
validate_args=validate_args,
)
qr = Independent(Normal(world_model.reward_model(latent_states), 1), 1)

# compute predictions for terminal steps, if required
if cfg.algo.world_model.use_continues and world_model.continue_model:
qc = Independent(
Bernoulli(logits=world_model.continue_model(latent_states), validate_args=validate_args),
1,
validate_args=validate_args,
)
qc = Independent(Bernoulli(logits=world_model.continue_model(latent_states)), 1)
continue_targets = (1 - data["dones"]) * cfg.algo.gamma
else:
qc = continue_targets = None
@@ -200,16 +184,8 @@
# an Independent distribution is needed to create (batch_size * sequence_length)
# independent distributions, each producing a sample of size equal to the stochastic size
posteriors_dist = Independent(
Normal(posteriors_mean, posteriors_std, validate_args=validate_args),
1,
validate_args=validate_args,
)
priors_dist = Independent(
Normal(priors_mean, priors_std, validate_args=validate_args),
1,
validate_args=validate_args,
)
posteriors_dist = Independent(Normal(posteriors_mean, posteriors_std), 1)
priors_dist = Independent(Normal(priors_mean, priors_std), 1)

# world model optimization step
world_optimizer.zero_grad(set_to_none=True)
@@ -351,11 +327,7 @@ def train(
# (to match the dimension with the lambda values),
# it removes the last imagined state in the trajectory
# because it is used only for correctly computing the lambda values
qv = Independent(
Normal(critic(imagined_trajectories.detach())[:-1], 1, validate_args=validate_args),
1,
validate_args=validate_args,
)
qv = Independent(Normal(critic(imagined_trajectories.detach())[:-1], 1), 1)

# critic optimization step
critic_optimizer.zero_grad(set_to_none=True)
6 changes: 3 additions & 3 deletions sheeprl/algos/dreamer_v1/utils.py
@@ -78,7 +78,7 @@ def compute_lambda_values(


def compute_stochastic_state(
state_information: Tensor, event_shape: int = 1, min_std: float = 0.1, validate_args: bool = False
state_information: Tensor, event_shape: int = 1, min_std: float = 0.1
) -> Tuple[Tuple[Tensor, Tensor], Tensor]:
"""
Compute the stochastic state from the information of the distribution of the stochastic state.
@@ -97,12 +97,12 @@
"""
mean, std = torch.chunk(state_information, 2, -1)
std = F.softplus(std) + min_std
state_distribution: Distribution = Normal(mean, std, validate_args=validate_args)
state_distribution: Distribution = Normal(mean, std)
if event_shape:
# an Independent distribution is needed to create (batch_size * sequence_length)
# independent distributions, each producing a sample of size equal to the stochastic size
state_distribution = Independent(state_distribution, event_shape, validate_args=validate_args)
state_distribution = Independent(state_distribution, event_shape)
stochastic_state = state_distribution.rsample()
return (mean, std), stochastic_state
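A hypothetical usage sketch of the simplified `compute_stochastic_state` signature; the shapes and the `stochastic_size` value are illustrative, not taken from the repository:

```python
import torch

from sheeprl.algos.dreamer_v1.utils import compute_stochastic_state

stochastic_size = 30  # illustrative value
# Model output with mean and std concatenated on the last dimension.
state_information = torch.randn(16, 64, 2 * stochastic_size)
(mean, std), stochastic_state = compute_stochastic_state(state_information, event_shape=1, min_std=0.1)
assert stochastic_state.shape == (16, 64, stochastic_size)  # one sample per (sequence, batch) element
```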

33 changes: 18 additions & 15 deletions sheeprl/algos/dreamer_v2/README.md
@@ -35,8 +35,8 @@ The three losses of DreamerV2 are implemented in the `loss.py` file. The *recons
The reconstruction loss is computed as follows:
```python
def reconstruction_loss(
po: Distribution,
observations: Tensor,
po: Dict[str, Distribution],
observations: Dict[str, Tensor],
pr: Distribution,
rewards: Tensor,
priors_logits: Tensor,
@@ -48,32 +48,35 @@ def reconstruction_loss(
pc: Optional[Distribution] = None,
continue_targets: Optional[Tensor] = None,
discount_scale_factor: float = 1.0,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
device = observations.device
observation_loss = -po.log_prob(observations).mean()
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
observation_loss = -sum([po[k].log_prob(observations[k]).mean() for k in po.keys()])
reward_loss = -pr.log_prob(rewards).mean()
# KL balancing
lhs = kl_divergence(
OneHotCategoricalStraightThroughValidateArgs(logits=posteriors_logits.detach()),
OneHotCategoricalStraightThroughValidateArgs(logits=priors_logits),
lhs = kl = kl_divergence(
Independent(OneHotCategoricalStraightThrough(logits=posteriors_logits.detach()), 1),
Independent(OneHotCategoricalStraightThrough(logits=priors_logits), 1),
)
rhs = kl_divergence(
OneHotCategoricalStraightThroughValidateArgs(logits=posteriors_logits),
OneHotCategoricalStraightThroughValidateArgs(logits=priors_logits.detach()),
Independent(OneHotCategoricalStraightThrough(logits=posteriors_logits), 1),
Independent(OneHotCategoricalStraightThrough(logits=priors_logits.detach()), 1),
)
kl_free_nats = torch.tensor([kl_free_nats], device=lhs.device)
if kl_free_avg:
loss_lhs = torch.maximum(lhs.mean(), kl_free_nats)
loss_rhs = torch.maximum(rhs.mean(), kl_free_nats)
lhs = lhs.mean()
rhs = rhs.mean()
free_nats = torch.full_like(lhs, kl_free_nats)
loss_lhs = torch.maximum(lhs, free_nats)
loss_rhs = torch.maximum(rhs, free_nats)
else:
free_nats = torch.full_like(lhs, kl_free_nats)
loss_lhs = torch.maximum(lhs, kl_free_nats).mean()
loss_rhs = torch.maximum(rhs, kl_free_nats).mean()
kl_loss = kl_balancing_alpha * loss_lhs + (1 - kl_balancing_alpha) * loss_rhs
continue_loss = torch.tensor(0, device=device)
if pc is not None and continue_targets is not None:
continue_loss = discount_scale_factor * -pc.log_prob(continue_targets).mean()
else:
continue_loss = torch.zeros_like(reward_loss)
reconstruction_loss = kl_regularizer * kl_loss + observation_loss + reward_loss + continue_loss
return reconstruction_loss, kl_loss, reward_loss, observation_loss, continue_loss
return reconstruction_loss, kl, kl_loss, reward_loss, observation_loss, continue_loss
```
Here it is necessary to define some hyper-parameters: *(i)* `kl_free_nats`, the minimum value of the *KL loss* (defaults to 0); *(ii)* the `kl_regularizer` parameter that scales the *KL loss*; *(iii)* whether or not to compute the *continue loss*; and *(iv)* `discount_scale_factor`, the parameter that scales the *continue loss*.
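For intuition, here is a small self-contained sketch of how `kl_free_nats` and the balancing coefficient interact (it mirrors the `kl_free_avg=True` branch of the snippet above; the logits, shapes, and hyper-parameter values are made up and not part of the repository):

```python
import torch
from torch.distributions import Independent, OneHotCategoricalStraightThrough, kl_divergence

posteriors_logits = torch.randn(8, 32, 32)  # (batch, discrete, classes) - illustrative
priors_logits = torch.randn(8, 32, 32)
kl_balancing_alpha, kl_free_nats = 0.8, 1.0

# KL balancing: with alpha > 0.5 the prior is trained towards the posterior (lhs,
# posterior detached) more strongly than the posterior is regularized towards the
# prior (rhs, prior detached).
lhs = kl_divergence(
    Independent(OneHotCategoricalStraightThrough(logits=posteriors_logits.detach()), 1),
    Independent(OneHotCategoricalStraightThrough(logits=priors_logits), 1),
).mean()
rhs = kl_divergence(
    Independent(OneHotCategoricalStraightThrough(logits=posteriors_logits), 1),
    Independent(OneHotCategoricalStraightThrough(logits=priors_logits.detach()), 1),
).mean()
# Free nats: gradients vanish once the averaged KL drops below `kl_free_nats`.
free_nats = torch.full_like(lhs, kl_free_nats)
kl_loss = kl_balancing_alpha * torch.maximum(lhs, free_nats) + (1 - kl_balancing_alpha) * torch.maximum(rhs, free_nats)
```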

69 changes: 23 additions & 46 deletions sheeprl/algos/dreamer_v2/agent.py
@@ -11,15 +11,19 @@
from lightning.fabric import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor, nn
from torch.distributions import Distribution, Independent, Normal, TanhTransform, TransformedDistribution
from torch.distributions import (
Distribution,
Independent,
Normal,
OneHotCategorical,
OneHotCategoricalStraightThrough,
TanhTransform,
TransformedDistribution,
)

from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state, init_weights
from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder
from sheeprl.utils.distribution import (
OneHotCategoricalStraightThroughValidateArgs,
OneHotCategoricalValidateArgs,
TruncatedNormal,
)
from sheeprl.utils.distribution import TruncatedNormal
from sheeprl.utils.fabric import get_single_device_fabric
from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward

@@ -374,9 +378,7 @@ def _representation(self, recurrent_state: Tensor, embedded_obs: Tensor) -> Tupl
posterior (Tensor): the sampled posterior stochastic state.
"""
logits = self.representation_model(torch.cat((recurrent_state, embedded_obs), -1))
return logits, compute_stochastic_state(
logits, discrete=self.discrete, validate_args=self.distribution_cfg.validate_args
)
return logits, compute_stochastic_state(logits, discrete=self.discrete)

def _transition(self, recurrent_out: Tensor) -> Tuple[Tensor, Tensor]:
"""
@@ -388,9 +390,7 @@ def _transition(self, recurrent_out: Tensor) -> Tuple[Tensor, Tensor]:
prior (Tensor): the sampled prior stochastic state.
"""
logits = self.transition_model(recurrent_out)
return logits, compute_stochastic_state(
logits, discrete=self.discrete, validate_args=self.distribution_cfg.validate_args
)
return logits, compute_stochastic_state(logits, discrete=self.discrete)

def imagination(self, prior: Tensor, recurrent_state: Tensor, actions: Tensor) -> Tuple[Tensor, Tensor]:
"""
@@ -525,21 +525,15 @@ def forward(
if self.distribution == "tanh_normal":
mean = 5 * torch.tanh(mean / 5)
std = F.softplus(std + self.init_std) + self.min_std
actions_dist = Normal(mean, std, validate_args=self.distribution_cfg.validate_args)
actions_dist = Independent(
TransformedDistribution(
actions_dist, TanhTransform(), validate_args=self.distribution_cfg.validate_args
),
1,
validate_args=self.distribution_cfg.validate_args,
)
actions_dist = Normal(mean, std)
actions_dist = Independent(TransformedDistribution(actions_dist, TanhTransform()), 1)
elif self.distribution == "normal":
actions_dist = Normal(mean, std, validate_args=self.distribution_cfg.validate_args)
actions_dist = Independent(actions_dist, 1, validate_args=self.distribution_cfg.validate_args)
actions_dist = Normal(mean, std)
actions_dist = Independent(actions_dist, 1)
elif self.distribution == "trunc_normal":
std = 2 * torch.sigmoid((std + self.init_std) / 2) + self.min_std
dist = TruncatedNormal(torch.tanh(mean), std, -1, 1, validate_args=self.distribution_cfg.validate_args)
actions_dist = Independent(dist, 1, validate_args=self.distribution_cfg.validate_args)
dist = TruncatedNormal(torch.tanh(mean), std, -1, 1)
actions_dist = Independent(dist, 1)
if sample_actions:
actions = actions_dist.rsample()
else:
@@ -552,11 +546,7 @@
actions_dist: List[Distribution] = []
actions: List[Tensor] = []
for logits in pre_dist:
actions_dist.append(
OneHotCategoricalStraightThroughValidateArgs(
logits=logits, validate_args=self.distribution_cfg.validate_args
)
)
actions_dist.append(OneHotCategoricalStraightThrough(logits=logits))
if sample_actions:
actions.append(actions_dist[-1].rsample())
else:
@@ -575,11 +565,7 @@ def add_exploration_noise(
else:
expl_actions = []
for act in actions:
sample = (
OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False)
.sample()
.to(act.device)
)
sample = OneHotCategorical(logits=torch.zeros_like(act)).sample().to(act.device)
expl_actions.append(
torch.where(torch.rand(act.shape[:1], device=act.device) < expl_amount, sample, act)
)
@@ -663,11 +649,7 @@ def forward(
logits[t, b][torch.logical_not(mask["mask_equip_place"][t, b])] = -torch.inf
elif sampled_action == 18: # Destroy action
logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf
actions_dist.append(
OneHotCategoricalStraightThroughValidateArgs(
logits=logits, validate_args=self.distribution_cfg.validate_args
)
)
actions_dist.append(OneHotCategoricalStraightThrough(logits=logits))
if sample_actions:
actions.append(actions_dist[-1].rsample())
else:
@@ -704,9 +686,7 @@ def add_exploration_noise(
logits[t, b][torch.logical_not(mask["mask_equip_place"][t, b])] = -torch.inf
elif sampled_action == 18: # Destroy action
logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf
sample = (
OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False).sample().to(act.device)
)
sample = OneHotCategorical(logits=torch.zeros_like(act)).sample().to(act.device)
expl_amount = self._get_expl_amount(step)
# If action[0] was changed and is now critical, we also force the other two actions to change
# to satisfy the constraints of the environment
@@ -805,7 +785,6 @@ def __init__(
self.discrete_size = discrete_size
self.recurrent_state_size = recurrent_state_size
self.num_envs = num_envs
self.validate_args = self.actor.distribution_cfg.validate_args
self.actor_type = actor_type

def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
@@ -851,9 +830,7 @@ def get_actions(
torch.cat((self.stochastic_state, self.actions), -1), self.recurrent_state
)
posterior_logits = self.representation_model(torch.cat((self.recurrent_state, embedded_obs), -1))
stochastic_state = compute_stochastic_state(
posterior_logits, discrete=self.discrete_size, validate_args=self.validate_args
)
stochastic_state = compute_stochastic_state(posterior_logits, discrete=self.discrete_size)
self.stochastic_state = stochastic_state.view(
*stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size
)
