feat: added optimizations #168

Merged
merged 4 commits on Dec 13, 2023

4 changes: 2 additions & 2 deletions howto/register_new_algorithm.md
@@ -130,12 +130,12 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
)

# Create the optimizer and set it up with Fabric
optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters())
optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all")

# Create a metric aggregator to log the metrics
aggregator = None
if not MetricAggregator.disabled:
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator).to(device)
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

# Local data
rb = ReplayBuffer(cfg.algo.rollout_steps, cfg.env.num_envs, device=device, memmap=cfg.buffer.memmap)
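A note for readers who have not met the flag before: below is a minimal sketch (mine, not part of this PR) of what `_convert_="all"` changes, using a toy Adam config. With Hydra's default conversion mode, nested values such as `betas` may reach the target as OmegaConf containers; `_convert_="all"` recursively converts them to plain Python lists/dicts before the target is called, which is what the docs snippet above now shows.

```python
# Illustrative sketch of Hydra's _convert_ modes; the toy config below is assumed,
# it is not taken from sheeprl.
import hydra
import torch.nn as nn
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"optimizer": {"_target_": "torch.optim.Adam", "lr": 1e-3, "betas": [0.9, 0.999]}}
)
model = nn.Linear(4, 4)

# Default mode: nested containers may be passed through as ListConfig objects.
opt_default = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())

# _convert_="all": every OmegaConf container becomes a native list/dict first,
# so the optimizer only ever sees plain Python types.
opt_converted = hydra.utils.instantiate(
    cfg.optimizer, params=model.parameters(), _convert_="all"
)
print(type(opt_default.defaults["betas"]), type(opt_converted.defaults["betas"]))
```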
63 changes: 30 additions & 33 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -15,6 +15,7 @@
from tensordict import TensorDict
from tensordict.tensordict import TensorDictBase
from torch.distributions import Bernoulli, Independent, Normal
from torch.distributions.utils import logits_to_probs
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric

@@ -106,6 +107,7 @@ def train(
recurrent_state_size = cfg.algo.world_model.recurrent_model.recurrent_state_size
stochastic_size = cfg.algo.world_model.stochastic_size
device = fabric.device
data = {k: data[k] for k in data.keys()}
batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder}
batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder})
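A quick aside on the added `data = {k: data[k] for k in data.keys()}` line above: it is a shallow copy, so no tensor data moves; a plain dict simply avoids the per-access overhead of a wrapper container during the rest of the training step. A tiny sketch (the TensorDict input and the shapes are assumptions on my part):

```python
# Hedged sketch: shallow-copy a TensorDict into a plain dict before heavy per-key indexing.
import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.zeros(16, 4, 3, 64, 64), "reward": torch.zeros(16, 4, 1)},
    batch_size=[16, 4],
)
plain = {k: td[k] for k in td.keys()}  # same underlying tensors, plain-dict lookups from now on
assert plain["reward"].data_ptr() == td["reward"].data_ptr()  # nothing was copied
```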

@@ -229,21 +231,12 @@ def train(
cfg.algo.world_model.continue_scale_factor,
)
fabric.backward(rec_loss)
world_model_grads = None
if cfg.algo.world_model.clip_gradients is not None and cfg.algo.world_model.clip_gradients > 0:
world_model_grads = fabric.clip_gradients(
module=world_model, optimizer=world_optimizer, max_norm=cfg.algo.world_model.clip_gradients
)
world_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/world_model", world_model_grads.mean().detach())
aggregator.update("Loss/world_model_loss", rec_loss.detach())
aggregator.update("Loss/observation_loss", observation_loss.detach())
aggregator.update("Loss/reward_loss", reward_loss.detach())
aggregator.update("Loss/state_loss", state_loss.detach())
aggregator.update("Loss/continue_loss", continue_loss.detach())
aggregator.update("State/kl", kl.detach())
aggregator.update("State/post_entropy", posteriors_dist.entropy().mean().detach())
aggregator.update("State/prior_entropy", priors_dist.entropy().mean().detach())

# Behaviour Learning
# Unflatten first 2 dimensions of recurrent and posterior states in order
@@ -282,24 +275,12 @@ def train(
# it is necessary an Independent distribution because
# it is necessary to create (batch_size * sequence_length) independent distributions,
# each producing a sample of size equal to the number of values/rewards
predicted_values = Independent(
Normal(critic(imagined_trajectories), 1, validate_args=validate_args),
1,
validate_args=validate_args,
).mean
predicted_rewards = Independent(
Normal(world_model.reward_model(imagined_trajectories), 1, validate_args=validate_args),
1,
validate_args=validate_args,
).mean
predicted_values = critic(imagined_trajectories)
predicted_rewards = world_model.reward_model(imagined_trajectories)

# predict the probability that the episode will continue in the imagined states
if cfg.algo.world_model.use_continues and world_model.continue_model:
predicted_continues = Independent(
Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
1,
validate_args=validate_args,
).mean
predicted_continues = logits_to_probs(logits=world_model.continue_model(imagined_trajectories), is_binary=True)
else:
predicted_continues = torch.ones_like(predicted_rewards.detach()) * cfg.algo.gamma
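For the `predicted_continues` change just above, a small equivalence check (my own sketch, not PR code): for binary logits, the mean of a `Bernoulli` is the sigmoid of the logits, which is exactly what `logits_to_probs(..., is_binary=True)` computes, so the distribution objects can be skipped in the hot path.

```python
# Sketch: three ways to turn continue-logits into probabilities give the same result.
import torch
from torch.distributions import Bernoulli, Independent
from torch.distributions.utils import logits_to_probs

logits = torch.randn(15, 1024, 1)  # (horizon, batch * sequence, 1); shapes assumed
via_dist = Independent(Bernoulli(logits=logits), 1).mean
via_sigmoid = torch.sigmoid(logits)
via_utils = logits_to_probs(logits, is_binary=True)
assert torch.allclose(via_dist, via_sigmoid) and torch.allclose(via_dist, via_utils)
```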

@@ -356,14 +337,12 @@ def train(
# compute the policy loss
policy_loss = actor_loss(discount * lambda_values)
fabric.backward(policy_loss)
actor_grads = None
if cfg.algo.actor.clip_gradients is not None and cfg.algo.actor.clip_gradients > 0:
actor_grads = fabric.clip_gradients(
module=actor, optimizer=actor_optimizer, max_norm=cfg.algo.actor.clip_gradients
)
actor_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/actor", actor_grads.mean().detach())
aggregator.update("Loss/policy_loss", policy_loss.detach())

# Predict the values distribution only for the first H (horizon) imagined states
# (to match the dimension with the lambda values),
@@ -383,14 +362,30 @@ def train(
# for the log prob
value_loss = critic_loss(qv, lambda_values.detach(), discount[..., 0])
fabric.backward(value_loss)
critic_grads = None
if cfg.algo.critic.clip_gradients is not None and cfg.algo.critic.clip_gradients > 0:
critic_grads = fabric.clip_gradients(
module=critic, optimizer=critic_optimizer, max_norm=cfg.algo.critic.clip_gradients
)
critic_optimizer.step()

if aggregator and not aggregator.disabled:
aggregator.update("Grads/critic", critic_grads.mean().detach())
aggregator.update("Loss/world_model_loss", rec_loss.detach())
aggregator.update("Loss/observation_loss", observation_loss.detach())
aggregator.update("Loss/reward_loss", reward_loss.detach())
aggregator.update("Loss/state_loss", state_loss.detach())
aggregator.update("Loss/continue_loss", continue_loss.detach())
aggregator.update("State/kl", kl.detach())
aggregator.update("State/post_entropy", posteriors_dist.entropy().mean().detach())
aggregator.update("State/prior_entropy", priors_dist.entropy().mean().detach())
aggregator.update("Loss/policy_loss", policy_loss.detach())
aggregator.update("Loss/value_loss", value_loss.detach())
if world_model_grads:
aggregator.update("Grads/world_model", world_model_grads.mean().detach())
if actor_grads:
aggregator.update("Grads/actor", actor_grads.mean().detach())
if critic_grads:
aggregator.update("Grads/critic", critic_grads.mean().detach())

# Reset everything
actor_optimizer.zero_grad(set_to_none=True)
@@ -492,9 +487,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

# Optimizers
world_optimizer = hydra.utils.instantiate(cfg.algo.world_model.optimizer, params=world_model.parameters())
actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=actor.parameters())
critic_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=critic.parameters())
world_optimizer = hydra.utils.instantiate(
cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all"
)
actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=actor.parameters(), _convert_="all")
critic_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=critic.parameters(), _convert_="all")
if cfg.checkpoint.resume_from:
world_optimizer.load_state_dict(state["world_optimizer"])
actor_optimizer.load_state_dict(state["actor_optimizer"])
@@ -509,7 +506,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Metrics
aggregator = None
if not MetricAggregator.disabled:
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator).to(device)
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

# Local data
buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2
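The training-loop edits in this file follow one pattern: the gradient norm is only produced when clipping is actually configured, and all metric updates are deferred to a single guarded block after the optimizer steps. A standalone sketch of that pattern is below; `clip_and_step` and the plain-dict metrics are illustrative stand-ins, not sheeprl or Fabric APIs.

```python
# Hedged sketch of the "clip only if configured, log once at the end" pattern;
# the names here are illustrative, not taken from sheeprl.
from typing import Dict, Optional

import torch
from torch import nn


def clip_and_step(
    module: nn.Module, optimizer: torch.optim.Optimizer, max_norm: Optional[float]
) -> Optional[torch.Tensor]:
    """Clip only when a positive max_norm is configured; return the grad norm or None."""
    grads = None
    if max_norm is not None and max_norm > 0:
        grads = nn.utils.clip_grad_norm_(module.parameters(), max_norm)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return grads


model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
grad_norm = clip_and_step(model, opt, max_norm=100.0)

# Metrics are gathered in a single block at the end of the step; optional values
# are only logged when they were actually computed.
metrics: Dict[str, float] = {"Loss/example": loss.item()}
if grad_norm is not None:
    metrics["Grads/example"] = grad_norm.item()
print(metrics)
```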
7 changes: 4 additions & 3 deletions sheeprl/algos/dreamer_v1/loss.py
@@ -82,13 +82,14 @@ def reconstruction_loss(
continue_loss (Tensor): the value of the continue loss (0 if it is not computed).
reconstruction_loss (Tensor): the value of the overall reconstruction loss.
"""
device = rewards.device
observation_loss = -sum([qo[k].log_prob(observations[k]).mean() for k in qo.keys()])
reward_loss = -qr.log_prob(rewards).mean()
kl = kl_divergence(posteriors_dist, priors_dist).mean()
state_loss = torch.max(torch.tensor(kl_free_nats, device=device), kl)
continue_loss = torch.tensor(0, device=device)
free_nats = torch.full_like(kl, kl_free_nats)
state_loss = torch.max(kl, free_nats)
if qc is not None and continue_targets is not None:
continue_loss = continue_scale_factor * qc.log_prob(continue_targets)
else:
continue_loss = torch.zeros_like(reward_loss)
reconstruction_loss = kl_regularizer * state_loss + observation_loss + reward_loss + continue_loss
return reconstruction_loss, kl, state_loss, reward_loss, observation_loss, continue_loss
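On the free-nats change above: `torch.full_like(kl, kl_free_nats)` builds the threshold with the KL's own dtype and device, so the explicit `device=` bookkeeping disappears while the clamp itself is unchanged. A minimal sketch (values are made up):

```python
# Sketch: clamp the KL term at a floor of kl_free_nats without manual device handling.
import torch

kl = torch.tensor(1.7)  # stands in for kl_divergence(posteriors, priors).mean()
kl_free_nats = 3.0
state_loss = torch.max(kl, torch.full_like(kl, kl_free_nats))
# Below the floor the loss is a constant, so the KL term contributes no gradient there.
assert state_loss.item() == 3.0
```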
103 changes: 49 additions & 54 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -19,6 +19,7 @@
from tensordict.tensordict import TensorDictBase
from torch import Tensor
from torch.distributions import Bernoulli, Distribution, Independent, Normal
from torch.distributions.utils import logits_to_probs
from torch.optim import Optimizer
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric
@@ -120,12 +121,13 @@ def train(
stochastic_size = cfg.algo.world_model.stochastic_size
discrete_size = cfg.algo.world_model.discrete_size
device = fabric.device
data = {k: data[k] for k in data.keys()}
batch_obs = {k: data[k] / 255 - 0.5 for k in cfg.algo.cnn_keys.encoder}
batch_obs.update({k: data[k] for k in cfg.algo.mlp_keys.encoder})

# Given how the environment interaction works, we assume that the first element in a sequence
# is the first one, as if the environment has been reset
data["is_first"][0, :] = torch.tensor([1.0], device=fabric.device).expand_as(data["is_first"][0, :])
data["is_first"][0, :] = torch.ones_like(data["is_first"][0, :])

# Dynamic Learning
stoch_state_size = stochastic_size * discrete_size
@@ -213,6 +215,7 @@ def train(
validate_args=validate_args,
)
fabric.backward(rec_loss)
world_model_grads = None
if cfg.algo.world_model.clip_gradients is not None and cfg.algo.world_model.clip_gradients > 0:
world_model_grads = fabric.clip_gradients(
module=world_model,
@@ -221,36 +224,6 @@ def train(
error_if_nonfinite=False,
)
world_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/world_model", world_model_grads.mean().detach())
aggregator.update("Loss/world_model_loss", rec_loss.detach())
aggregator.update("Loss/observation_loss", observation_loss.detach())
aggregator.update("Loss/reward_loss", reward_loss.detach())
aggregator.update("Loss/state_loss", state_loss.detach())
aggregator.update("Loss/continue_loss", continue_loss.detach())
aggregator.update("State/kl", kl.mean().detach())
aggregator.update(
"State/post_entropy",
Independent(
OneHotCategoricalValidateArgs(logits=posteriors_logits.detach(), validate_args=validate_args),
1,
validate_args=validate_args,
)
.entropy()
.mean()
.detach(),
)
aggregator.update(
"State/prior_entropy",
Independent(
OneHotCategoricalValidateArgs(logits=priors_logits.detach(), validate_args=validate_args),
1,
validate_args=validate_args,
)
.entropy()
.mean()
.detach(),
)

# Behaviour Learning
# (1, batch_size * sequence_length, stochastic_size * discrete_size)
@@ -308,22 +281,10 @@ def train(
imagined_trajectories[i] = imagined_latent_state

# Predict values and rewards
predicted_target_values = Independent(
Normal(target_critic(imagined_trajectories), 1, validate_args=validate_args),
1,
validate_args=validate_args,
).mode
predicted_rewards = Independent(
Normal(world_model.reward_model(imagined_trajectories), 1, validate_args=validate_args),
1,
validate_args=validate_args,
).mode
predicted_target_values = target_critic(imagined_trajectories)
predicted_rewards = world_model.reward_model(imagined_trajectories)
if cfg.algo.world_model.use_continues and world_model.continue_model:
continues = Independent(
Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
1,
validate_args,
).mean
continues = logits_to_probs(logits=world_model.continue_model(imagined_trajectories), is_binary=True)
true_done = (1 - data["dones"]).reshape(1, -1, 1) * cfg.algo.gamma
continues = torch.cat((true_done, continues[1:]))
else:
@@ -385,14 +346,12 @@ def train(
entropy = torch.zeros_like(objective)
policy_loss = -torch.mean(discount[:-2].detach() * (objective + entropy.unsqueeze(-1)))
fabric.backward(policy_loss)
actor_grads = None
if cfg.algo.actor.clip_gradients is not None and cfg.algo.actor.clip_gradients > 0:
actor_grads = fabric.clip_gradients(
module=actor, optimizer=actor_optimizer, max_norm=cfg.algo.actor.clip_gradients, error_if_nonfinite=False
)
actor_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/actor", actor_grads.mean().detach())
aggregator.update("Loss/policy_loss", policy_loss.detach())

# Predict the values distribution only for the first H (horizon)
# imagined states (to match the dimension with the lambda values),
@@ -413,8 +372,42 @@
)
critic_optimizer.step()
if aggregator and not aggregator.disabled:
aggregator.update("Grads/critic", critic_grads.mean().detach())
aggregator.update("Loss/world_model_loss", rec_loss.detach())
aggregator.update("Loss/observation_loss", observation_loss.detach())
aggregator.update("Loss/reward_loss", reward_loss.detach())
aggregator.update("Loss/state_loss", state_loss.detach())
aggregator.update("Loss/continue_loss", continue_loss.detach())
aggregator.update("State/kl", kl.mean().detach())
aggregator.update(
"State/post_entropy",
Independent(
OneHotCategoricalValidateArgs(logits=posteriors_logits.detach(), validate_args=validate_args),
1,
validate_args=validate_args,
)
.entropy()
.mean()
.detach(),
)
aggregator.update(
"State/prior_entropy",
Independent(
OneHotCategoricalValidateArgs(logits=priors_logits.detach(), validate_args=validate_args),
1,
validate_args=validate_args,
)
.entropy()
.mean()
.detach(),
)
aggregator.update("Loss/policy_loss", policy_loss.detach())
aggregator.update("Loss/value_loss", value_loss.detach())
if world_model_grads:
aggregator.update("Grads/world_model", world_model_grads.mean().detach())
if actor_grads:
aggregator.update("Grads/actor", actor_grads.mean().detach())
if critic_grads:
aggregator.update("Grads/critic", critic_grads.mean().detach())

# Reset everything
actor_optimizer.zero_grad(set_to_none=True)
@@ -519,9 +512,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

# Optimizers
world_optimizer = hydra.utils.instantiate(cfg.algo.world_model.optimizer, params=world_model.parameters())
actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=actor.parameters())
critic_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=critic.parameters())
world_optimizer = hydra.utils.instantiate(
cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all"
)
actor_optimizer = hydra.utils.instantiate(cfg.algo.actor.optimizer, params=actor.parameters(), _convert_="all")
critic_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=critic.parameters(), _convert_="all")
if cfg.checkpoint.resume_from:
world_optimizer.load_state_dict(state["world_optimizer"])
actor_optimizer.load_state_dict(state["actor_optimizer"])
@@ -537,7 +532,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Metrics
aggregator = None
if not MetricAggregator.disabled:
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator).to(device)
aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)

# Local data
buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2
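One equivalence behind several edits in this file is worth spelling out (a sketch under the assumption that the critic and reward heads parameterise a unit-variance Normal): the mean and the mode of `Independent(Normal(loc, 1), 1)` are both just `loc`, so calling the networks directly returns the same values without constructing distribution objects.

```python
# Sketch: for a unit-variance Normal, mode == mean == loc, so the wrapper adds nothing here.
import torch
from torch.distributions import Independent, Normal

loc = torch.randn(1024, 255)        # pretend network output; the shape is assumed
dist = Independent(Normal(loc, 1), 1)
assert torch.equal(dist.mean, loc)
assert torch.equal(dist.mode, loc)  # .mode needs a PyTorch version that defines it
```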
13 changes: 8 additions & 5 deletions sheeprl/algos/dreamer_v2/loss.py
@@ -60,7 +60,6 @@ def reconstruction_loss(
continue_loss (Tensor): the value of the continue loss (0 if it is not computed).
reconstruction_loss (Tensor): the value of the overall reconstruction loss.
"""
device = rewards.device
observation_loss = -sum([po[k].log_prob(observations[k]).mean() for k in po.keys()])
reward_loss = -pr.log_prob(rewards).mean()
# KL balancing
@@ -90,16 +89,20 @@ def reconstruction_loss(
validate_args=validate_args,
),
)
kl_free_nats = torch.tensor([kl_free_nats], device=lhs.device)
if kl_free_avg:
loss_lhs = torch.maximum(lhs.mean(), kl_free_nats)
loss_rhs = torch.maximum(rhs.mean(), kl_free_nats)
lhs = lhs.mean()
rhs = rhs.mean()
free_nats = torch.full_like(lhs, kl_free_nats)
loss_lhs = torch.maximum(lhs, free_nats)
loss_rhs = torch.maximum(rhs, free_nats)
else:
free_nats = torch.full_like(lhs, kl_free_nats)
loss_lhs = torch.maximum(lhs, kl_free_nats).mean()
loss_rhs = torch.maximum(rhs, kl_free_nats).mean()
kl_loss = kl_balancing_alpha * loss_lhs + (1 - kl_balancing_alpha) * loss_rhs
continue_loss = torch.tensor(0, device=device)
if pc is not None and continue_targets is not None:
continue_loss = discount_scale_factor * -pc.log_prob(continue_targets).mean()
else:
continue_loss = torch.zeros_like(reward_loss)
reconstruction_loss = kl_regularizer * kl_loss + observation_loss + reward_loss + continue_loss
return reconstruction_loss, kl, kl_loss, reward_loss, observation_loss, continue_loss
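To close, the KL-balancing branch touched above can be read as the standalone sketch below (my paraphrase; the shapes, the alpha value, and which detached variant is called lhs are assumptions, since the full function body is not shown in this hunk): the same KL is computed twice, once with detached posteriors and once with detached priors, each side is clamped at the free-nats floor built with `torch.full_like`, and the two sides are mixed with `kl_balancing_alpha`.

```python
# Sketch of DreamerV2-style KL balancing with free nats; shapes and settings are assumed.
import torch
from torch.distributions import Independent, OneHotCategorical, kl_divergence


def categorical(logits: torch.Tensor) -> Independent:
    # One independent discrete distribution per stochastic slot.
    return Independent(OneHotCategorical(logits=logits), 1)


posterior_logits = torch.randn(8, 64, 32, 32)  # (sequence, batch, stochastic, discrete)
prior_logits = torch.randn(8, 64, 32, 32)
kl_balancing_alpha, kl_free_nats = 0.8, 1.0

lhs = kl_divergence(categorical(posterior_logits.detach()), categorical(prior_logits)).mean()
rhs = kl_divergence(categorical(posterior_logits), categorical(prior_logits.detach())).mean()
free_nats = torch.full_like(lhs, kl_free_nats)
kl_loss = kl_balancing_alpha * torch.maximum(lhs, free_nats) + (
    1 - kl_balancing_alpha
) * torch.maximum(rhs, free_nats)
```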