feat: added optimizations (#168)
* feat: added optimizations

* fix: avoid creation of distributions when not necessary

* fix: dreamer_v2 loss

* fix: added _convert_='all' to optimizer creation
michele-milesi authored Dec 13, 2023
1 parent e799f55 commit 796cd96
Showing 22 changed files with 447 additions and 408 deletions.
4 changes: 2 additions & 2 deletions howto/register_new_algorithm.md
@@ -130,12 +130,12 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
)

# Create the optimizer and set it up with Fabric
optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters())
optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all")

# Create a metric aggregator to log the metrics
aggregator = None
if not MetricAggregator.disabled:
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator).to(device)
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

# Local data
rb = ReplayBuffer(cfg.algo.rollout_steps, cfg.env.num_envs, device=device, memmap=cfg.buffer.memmap)
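The only change in this how-to is threading `_convert_="all"` through `hydra.utils.instantiate`. A minimal sketch of what that flag does, assuming a hypothetical stand-alone Adam config in place of the real `cfg.algo.optimizer`: by default Hydra hands nested config values to the target as OmegaConf containers, while `_convert_="all"` converts them to plain Python dicts and lists first.

```python
import hydra.utils
import torch
from omegaconf import OmegaConf

# Hypothetical optimizer config standing in for cfg.algo.optimizer.
opt_cfg = OmegaConf.create({"_target_": "torch.optim.Adam", "lr": 1e-3, "betas": [0.9, 0.999]})
params = [torch.nn.Parameter(torch.zeros(1))]

# Default behaviour (_convert_="none"): `betas` reaches torch.optim.Adam as an
# OmegaConf ListConfig. With _convert_="all" it arrives as a plain Python list,
# so the optimizer's stored defaults contain only native types.
optimizer = hydra.utils.instantiate(opt_cfg, params=params, _convert_="all")
print(type(optimizer.defaults["betas"]))  # <class 'list'>
```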
63 changes: 30 additions & 33 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -15,6 +15,7 @@
from tensordict import TensorDict
from tensordict.tensordict import TensorDictBase
from torch.distributions import Bernoulli, Independent, Normal
from torch.distributions.utils import logits_to_probs
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric

@@ -106,6 +107,7 @@ def train(
recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size
stochastic_size = cfg.algo.world_model.stochastic_size
device = fabric.device
data = {k: data[k] for k in data.keys()}
batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder}
batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder})
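The new `data = {k: data[k] for k in data.keys()}` line reads as a micro-optimization: a plain-dict view of the batch is built once, so the many per-key lookups later in `train()` hit an ordinary Python dict instead of going back through the original container's indexing machinery. A hedged sketch with made-up keys and shapes, assuming the batch arrives as a `TensorDict` (consistent with the `tensordict` imports shown earlier in this file's diff):

```python
import torch
from tensordict import TensorDict

# Made-up batch: (sequence_length, num_envs, ...) tensors, as in the training loop.
td = TensorDict(
    {"rgb": torch.zeros(64, 4, 3, 64, 64), "rewards": torch.zeros(64, 4, 1)},
    batch_size=[64, 4],
)

# Plain-dict view: same underlying tensors, no copies, cheaper repeated access.
data = {k: td[k] for k in td.keys()}
batch_obs = {k: data[k] / 255 - 0.5 for k in ["rgb"]}  # scale image obs to [-0.5, 0.5]
```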

@@ -229,21 +231,12 @@ def train(
cfg.algo.world_model.continue_scale_factor,
)
fabric.backward(rec_loss)
world_model_grads = None
if cfg.algo.world_model.clip_gradients is not None and cfg.algo.world_model.clip_gradients > 0:
world_model_grads = fabric.clip_gradients(
module=world_model, optimizer=world_optimizer, max_norm=cfg.algo.world_model.clip_gradients
)
world_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/world_model", world_model_grads.mean().detach())
aggregator.update("Loss/world_model_loss", rec_loss.detach())
aggregator.update("Loss/observation_loss", observation_loss.detach())
aggregator.update("Loss/reward_loss", reward_loss.detach())
aggregator.update("Loss/state_loss", state_loss.detach())
aggregator.update("Loss/continue_loss", continue_loss.detach())
aggregator.update("State/kl", kl.detach())
aggregator.update("State/post_entropy", posteriors_dist.entropy().mean().detach())
aggregator.update("State/prior_entropy", priors_dist.entropy().mean().detach())

# Behaviour Learning
# Unflatten first 2 dimensions of recurrent and posterior states in order
@@ -282,24 +275,12 @@ def train(
# an Independent distribution is needed because we must create
# (batch_size * sequence_length) independent distributions,
# each producing a sample of size equal to the number of values/rewards
predicted_values = Independent(
Normal(critic(imagined_trajectories), 1, validate_args=validate_args),
1,
validate_args=validate_args,
).mean
predicted_rewards = Independent(
Normal(world_model.reward_model(imagined_trajectories), 1, validate_args=validate_args),
1,
validate_args=validate_args,
).mean
predicted_values = critic(imagined_trajectories)
predicted_rewards = world_model.reward_model(imagined_trajectories)

# predict the probability that the episode will continue in the imagined states
if cfg.algo.world_model.use_continues and world_model.continue_model:
predicted_continues = Independent(
Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
1,
validate_args=validate_args,
).mean
predicted_continues = logits_to_probs(logits=world_model.continue_model(imagined_trajectories), is_binary=True)
else:
predicted_continues = torch.ones_like(predicted_rewards.detach()) * cfg.algo.gamma
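The pattern in this hunk (and in the value/reward predictions just above) is the core of the "avoid creation of distributions when not necessary" fix: the mean (and mode) of `Normal(x, 1)` is just `x`, and the mean of `Bernoulli(logits=l)` is `sigmoid(l)`, which `logits_to_probs(..., is_binary=True)` computes directly. A small self-contained check of the equivalence:

```python
import torch
from torch.distributions import Bernoulli, Independent, Normal
from torch.distributions.utils import logits_to_probs

values = torch.randn(6, 8)       # stand-in for critic / reward-model outputs
cont_logits = torch.randn(6, 1)  # stand-in for continue-model outputs

# Wrapping the network output in Normal(x, 1) only to read .mean (or .mode) back
# returns the output unchanged, so the distribution object is pure overhead.
assert torch.equal(Independent(Normal(values, 1), 1).mean, values)

# Likewise, Bernoulli(logits=l).mean is sigmoid(l); logits_to_probs does that directly.
assert torch.allclose(
    Independent(Bernoulli(logits=cont_logits), 1).mean,
    logits_to_probs(cont_logits, is_binary=True),
)
```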

@@ -356,14 +337,12 @@ def train(
# compute the policy loss
policy_loss = actor_loss(discount * lambda_values)
fabric.backward(policy_loss)
actor_grads = None
if cfg.algo.actor.clip_gradients is not None and cfg.algo.actor.clip_gradients > 0:
actor_grads = fabric.clip_gradients(
module=actor, optimizer=actor_optimizer, max_norm=cfg.algo.actor.clip_gradients
)
actor_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/actor", actor_grads.mean().detach())
aggregator.update("Loss/policy_loss", policy_loss.detach())

# Predict the values distribution only for the first H (horizon) imagined states
# (to match the dimension with the lambda values),
@@ -383,14 +362,30 @@ def train(
# for the log prob
value_loss = critic_loss(qv, lambda_values.detach(), discount[..., 0])
fabric.backward(value_loss)
critic_grads = None
if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0:
critic_grads = fabric.clip_gradients(
module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients
)
critic_optimizer.step()

if aggregator and not aggregator.disabled:
aggregator.update("Grads/critic", critic_grads.mean().detach())
aggregator.update("Loss/world_model_loss", rec_loss.detach())
aggregator.update("Loss/observation_loss", observation_loss.detach())
aggregator.update("Loss/reward_loss", reward_loss.detach())
aggregator.update("Loss/state_loss", state_loss.detach())
aggregator.update("Loss/continue_loss", continue_loss.detach())
aggregator.update("State/kl", kl.detach())
aggregator.update("State/post_entropy", posteriors_dist.entropy().mean().detach())
aggregator.update("State/prior_entropy", priors_dist.entropy().mean().detach())
aggregator.update("Loss/policy_loss", policy_loss.detach())
aggregator.update("Loss/value_loss", value_loss.detach())
if world_model_grads:
aggregator.update("Grads/world_model", world_model_grads.mean().detach())
if actor_grads:
aggregator.update("Grads/actor", actor_grads.mean().detach())
if critic_grads:
aggregator.update("Grads/critic", critic_grads.mean().detach())

# Reset everything
actor_optimizer.zero_grad(set_to_none=True)
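Besides the distribution changes, all `aggregator.update` calls are now gathered into the single guarded block above instead of being sprinkled after each optimizer step, and gradient norms are logged only when clipping actually produced one. An illustrative sketch of the pattern (the function, argument names, and config fields here are stand-ins, not SheepRL's API):

```python
def optimize(fabric, module, optimizer, loss, clip_gradients=None, aggregator=None, tag="model"):
    """Illustrative step: backward, optionally clip, step, then log everything once."""
    fabric.backward(loss)

    grad_norm = None  # stays None when clipping is disabled
    if clip_gradients is not None and clip_gradients > 0:
        grad_norm = fabric.clip_gradients(module=module, optimizer=optimizer, max_norm=clip_gradients)
    optimizer.step()

    # One guarded block: the hot path pays no logging cost when metrics are disabled.
    if aggregator and not aggregator.disabled:
        aggregator.update(f"Loss/{tag}_loss", loss.detach())
        if grad_norm is not None:
            aggregator.update(f"Grads/{tag}", grad_norm.mean().detach())
```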
@@ -492,9 +487,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

# Optimizers
world_optimizer = hydra.utils.instantiate(cfg.algo.world_model.optimizer, params=world_model.parameters())
actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=actor.parameters())
critic_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=critic.parameters())
world_optimizer = hydra.utils.instantiate(
cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all"
)
actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=actor.parameters(), _convert_="all")
critic_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=critic.parameters(), _convert_="all")
if cfg.checkpoint.resume_from:
world_optimizer.load_state_dict(state["world_optimizer"])
actor_optimizer.load_state_dict(state["actor_optimizer"])
@@ -509,7 +506,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Metrics
aggregator = None
if not MetricAggregator.disabled:
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator).to(device)
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

# Local data
buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2
7 changes: 4 additions & 3 deletions sheeprl/algos/dreamer_v1/loss.py
@@ -82,13 +82,14 @@ def reconstruction_loss(
continue_loss (Tensor): the value of the continue loss (0 if it is not computed).
reconstruction_loss (Tensor): the value of the overall reconstruction loss.
"""
device = rewards.device
observation_loss = -sum([qo[k].log_prob(observations[k]).mean() for k in qo.keys()])
reward_loss = -qr.log_prob(rewards).mean()
kl = kl_divergence(posteriors_dist, priors_dist).mean()
state_loss = torch.max(torch.tensor(kl_free_nats, device=device), kl)
continue_loss = torch.tensor(0, device=device)
free_nats = torch.full_like(kl, kl_free_nats)
state_loss = torch.max(kl, free_nats)
if qc is not None and continue_targets is not None:
continue_loss = continue_scale_factor * qc.log_prob(continue_targets)
else:
continue_loss = torch.zeros_like(reward_loss)
reconstruction_loss = kl_regularizer * state_loss + observation_loss + reward_loss + continue_loss
return reconstruction_loss, kl, state_loss, reward_loss, observation_loss, continue_loss
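The loss now builds its constant tensors with the `*_like` factories instead of `torch.tensor(..., device=device)`, presumably to inherit dtype and device from an existing tensor and to skip constructing a fresh tensor from Python scalars on every call. A quick sketch showing that the free-nats clipping itself is unchanged:

```python
import torch

kl = torch.rand(16)  # stand-in for the per-sample KL term
kl_free_nats = 3.0

# Before: a scalar tensor built from a Python float each call, then broadcast.
old_threshold = torch.tensor(kl_free_nats, device=kl.device)
# After: a tensor that matches kl's shape, dtype and device directly.
new_threshold = torch.full_like(kl, kl_free_nats)

assert torch.equal(torch.max(old_threshold, kl), torch.max(kl, new_threshold))

# The same idea applies to the default continue loss: zeros_like(reward_loss)
# replaces torch.tensor(0, device=device) and keeps the dtype consistent.
```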
103 changes: 49 additions & 54 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -19,6 +19,7 @@
from tensordict.tensordict import TensorDictBase
from torch import Tensor
from torch.distributions import Bernoulli, Distribution, Independent, Normal
from torch.distributions.utils import logits_to_probs
from torch.optim import Optimizer
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric
@@ -120,12 +121,13 @@ def train(
stochastic_size = cfg.algo.world_model.stochastic_size
discrete_size = cfg.algo.world_model.discrete_size
device = fabric.device
data = {k: data[k] for k in data.keys()}
batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder}
batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder})

# Given how the environment interaction works, we treat the first element of each sampled sequence
# as the start of an episode, as if the environment had just been reset
data["is_first"][0, :] = torch.tensor([1.0], device=fabric.device).expand_as(data["is_first"][0, :])
data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :])

# Dynamic Learning
stoch_state_size = stochastic_size * discrete_size
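One more small allocation change a few lines above: the `is_first` reset row is now filled with `torch.ones_like` instead of building a `[1.0]` tensor and expanding it. A quick check that the two are equivalent (the shape here is made up):

```python
import torch

is_first = torch.zeros(64, 16, 1)  # (sequence_length, num_envs, 1), illustrative shape
row = is_first[0, :]

old = torch.tensor([1.0], device=row.device).expand_as(row)  # tensor built from Python data, then broadcast
new = torch.ones_like(row)                                   # allocated with row's shape, dtype and device
assert torch.equal(old, new)
```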
@@ -213,6 +215,7 @@ def train(
validate_args=validate_args,
)
fabric.backward(rec_loss)
world_model_grads = None
if cfg.algo.world_model.clip_gradients is not None and cfg.algo.world_model.clip_gradients > 0:
world_model_grads = fabric.clip_gradients(
module=world_model,
@@ -221,36 +224,6 @@
error_if_nonfinite=False,
)
world_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/world_model", world_model_grads.mean().detach())
aggregator.update("Loss/world_model_loss", rec_loss.detach())
aggregator.update("Loss/observation_loss", observation_loss.detach())
aggregator.update("Loss/reward_loss", reward_loss.detach())
aggregator.update("Loss/state_loss", state_loss.detach())
aggregator.update("Loss/continue_loss", continue_loss.detach())
aggregator.update("State/kl", kl.mean().detach())
aggregator.update(
"State/post_entropy",
Independent(
OneHotCategoricalValidateArgs(logits=posteriors_logits.detach(), validate_args=validate_args),
1,
validate_args=validate_args,
)
.entropy()
.mean()
.detach(),
)
aggregator.update(
"State/prior_entropy",
Independent(
OneHotCategoricalValidateArgs(logits=priors_logits.detach(), validate_args=validate_args),
1,
validate_args=validate_args,
)
.entropy()
.mean()
.detach(),
)

# Behaviour Learning
# (1, batch_size * sequence_length, stochastic_size * discrete_size)
@@ -308,22 +281,10 @@ def train(
imagined_trajectories[i] = imagined_latent_state

# Predict values and rewards
predicted_target_values = Independent(
Normal(target_critic(imagined_trajectories), 1, validate_args=validate_args),
1,
validate_args=validate_args,
).mode
predicted_rewards = Independent(
Normal(world_model.reward_model(imagined_trajectories), 1, validate_args=validate_args),
1,
validate_args=validate_args,
).mode
predicted_target_values = target_critic(imagined_trajectories)
predicted_rewards = world_model.reward_model(imagined_trajectories)
if cfg.algo.world_model.use_continues and world_model.continue_model:
continues = Independent(
Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
1,
validate_args,
).mean
continues = logits_to_probs(logits=world_model.continue_model(imagined_trajectories), is_binary=True)
true_done = (1 - data["dones"]).reshape(1, -1, 1) * cfg.algo.gamma
continues = torch.cat((true_done, continues[1:]))
else:
@@ -385,14 +346,12 @@ def train(
entropy = torch.zeros_like(objective)
policy_loss = -torch.mean(discount[:-2].detach() * (objective + entropy.unsqueeze(-1)))
fabric.backward(policy_loss)
actor_grads = None
if cfg.algo.actor.clip_gradients is not None and cfg.algo.actor.clip_gradients > 0:
actor_grads = fabric.clip_gradients(
module=actor, optimizer=actor_optimizer, max_norm=cfg.algo.actor.clip_gradients, error_if_nonfinite=False
)
actor_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/actor", actor_grads.mean().detach())
aggregator.update("Loss/policy_loss", policy_loss.detach())

# Predict the values distribution only for the first H (horizon)
# imagined states (to match the dimension with the lambda values),
@@ -413,8 +372,42 @@
)
critic_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/critic", critic_grads.mean().detach())
aggregator.update("Loss/world_model_loss", rec_loss.detach())
aggregator.update("Loss/observation_loss", observation_loss.detach())
aggregator.update("Loss/reward_loss", reward_loss.detach())
aggregator.update("Loss/state_loss", state_loss.detach())
aggregator.update("Loss/continue_loss", continue_loss.detach())
aggregator.update("State/kl", kl.mean().detach())
aggregator.update(
"State/post_entropy",
Independent(
OneHotCategoricalValidateArgs(logits=posteriors_logits.detach(), validate_args=validate_args),
1,
validate_args=validate_args,
)
.entropy()
.mean()
.detach(),
)
aggregator.update(
"State/prior_entropy",
Independent(
OneHotCategoricalValidateArgs(logits=priors_logits.detach(), validate_args=validate_args),
1,
validate_args=validate_args,
)
.entropy()
.mean()
.detach(),
)
aggregator.update("Loss/policy_loss", policy_loss.detach())
aggregator.update("Loss/value_loss", value_loss.detach())
if world_model_grads:
aggregator.update("Grads/world_model", world_model_grads.mean().detach())
if actor_grads:
aggregator.update("Grads/actor", actor_grads.mean().detach())
if critic_grads:
aggregator.update("Grads/critic", critic_grads.mean().detach())

# Reset everything
actor_optimizer.zero_grad(set_to_none=True)
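In DreamerV2 the consolidation has an extra payoff: the posterior/prior entropy metrics build `OneHotCategoricalValidateArgs` distributions purely for logging, and they now sit inside the `if aggregator and not aggregator.disabled:` guard, so that work is skipped entirely when metrics are disabled. A hedged sketch of the entropy computation itself, using torch's own `OneHotCategorical` as a stand-in for SheepRL's `OneHotCategoricalValidateArgs`:

```python
import torch
from torch.distributions import Independent, OneHotCategorical

# Made-up shape: (sequence_length, batch_size, stochastic_size, discrete_size)
posteriors_logits = torch.randn(64, 16, 32, 32)

# Entropy of the stacked categorical posterior, summed over the stochastic dimension
# (that is what Independent(..., 1) does before .entropy()), then averaged for logging.
post_entropy = Independent(OneHotCategorical(logits=posteriors_logits), 1).entropy().mean()
print(post_entropy)
```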
@@ -519,9 +512,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

# Optimizers
world_optimizer = hydra.utils.instantiate(cfg.algo.world_model.optimizer, params=world_model.parameters())
actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=actor.parameters())
critic_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=critic.parameters())
world_optimizer = hydra.utils.instantiate(
cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all"
)
actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=actor.parameters(), _convert_="all")
critic_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=critic.parameters(), _convert_="all")
if cfg.checkpoint.resume_from:
world_optimizer.load_state_dict(state["world_optimizer"])
actor_optimizer.load_state_dict(state["actor_optimizer"])
@@ -537,7 +532,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Metrics
aggregator = None
if not MetricAggregator.disabled:
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator).to(device)
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

# Local data
buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2
13 changes: 8 additions & 5 deletions sheeprl/algos/dreamer_v2/loss.py
@@ -60,7 +60,6 @@ def reconstruction_loss(
continue_loss (Tensor): the value of the continue loss (0 if it is not computed).
reconstruction_loss (Tensor): the value of the overall reconstruction loss.
"""
device = rewards.device
observation_loss = -sum([po[k].log_prob(observations[k]).mean() for k in po.keys()])
reward_loss = -pr.log_prob(rewards).mean()
# KL balancing
@@ -90,16 +89,20 @@ def reconstruction_loss(
validate_args=validate_args,
),
)
kl_free_nats = torch.tensor([kl_free_nats], device=lhs.device)
if kl_free_avg:
loss_lhs = torch.maximum(lhs.mean(), kl_free_nats)
loss_rhs = torch.maximum(rhs.mean(), kl_free_nats)
lhs = lhs.mean()
rhs = rhs.mean()
free_nats = torch.full_like(lhs, kl_free_nats)
loss_lhs = torch.maximum(lhs, free_nats)
loss_rhs = torch.maximum(rhs, free_nats)
else:
free_nats = torch.full_like(lhs, kl_free_nats)
loss_lhs = torch.maximum(lhs, kl_free_nats).mean()
loss_rhs = torch.maximum(rhs, kl_free_nats).mean()
kl_loss = kl_balancing_alpha * loss_lhs + (1 - kl_balancing_alpha) * loss_rhs
continue_loss = torch.tensor(0, device=device)
if pc is not None and continue_targets is not None:
continue_loss = discount_scale_factor * -pc.log_prob(continue_targets).mean()
else:
continue_loss = torch.zeros_like(reward_loss)
reconstruction_loss = kl_regularizer * kl_loss + observation_loss + reward_loss + continue_loss
return reconstruction_loss, kl, kl_loss, reward_loss, observation_loss, continue_loss
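With the device-bound `torch.tensor([kl_free_nats], ...)` gone, the averaged KL-balancing path reads as below. A hedged sketch, assuming `lhs = KL(sg(posterior) || prior)` and `rhs = KL(posterior || sg(prior))` have already been computed per step:

```python
import torch

lhs = torch.rand(64, 16)  # KL(sg(posterior) || prior), made-up shape
rhs = torch.rand(64, 16)  # KL(posterior || sg(prior))
kl_free_nats = 1.0
kl_balancing_alpha = 0.8  # weight on the prior-training term, as in DreamerV2

# Free-average variant: average first, then clip each side at the free-nats floor.
lhs_mean, rhs_mean = lhs.mean(), rhs.mean()
free_nats = torch.full_like(lhs_mean, kl_free_nats)  # scalar tensor on the right device/dtype
loss_lhs = torch.maximum(lhs_mean, free_nats)
loss_rhs = torch.maximum(rhs_mean, free_nats)
kl_loss = kl_balancing_alpha * loss_lhs + (1 - kl_balancing_alpha) * loss_rhs
```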