Feature/replay ratio (#247)
* Decoupled RSSM for DV3 agent

* Initialize posterior with prior if is_first is True

* Fix PlayerDV3 creation in evaluation

* Fix representation_model

* Fix compute first prior state with a zero posterior

* DV3 replay ratio conversion

* Removed expl parameters dependent on old per_rank_gradient_steps

* feat: update repeats computation

* feat: update learning starts in config

* fix: remove files

* feat: update repeats

* feat: added replay ratio and update exploration

* Fix exploration actions computation on DV1

* Fix naming

* Add replay-ratio to SAC

* feat: added replay ratio to p2e algos

* feat: update configs and utils of p2e algos

* Add replay-ratio to SAC-AE

* Add DroQ replay ratio

* Fix tests

* Fix misspelling

* Fix wrong attribute accessing

* Fix naming and configs

* Ratio: account for pretrain steps

* Fix dreamer-vq actor naming

* feat: added ratio state to checkpoint in sac decoupled

* feat: added typing in Ratio class

* Move ratio.py to examples

* Log dreamer-v1 exploration amount

* Fix DV1 log expl amount

* Fix DV2 replay ratios

---------

Co-authored-by: Michele Milesi <[email protected]>
belerico and michele-milesi authored Mar 29, 2024
1 parent df3734a commit 3dc227b
Showing 60 changed files with 976 additions and 1,064 deletions.
74 changes: 74 additions & 0 deletions examples/ratio.py
@@ -0,0 +1,74 @@
import warnings
from typing import Any, Dict, Mapping


class Ratio:
    """Directly taken from Hafner et al. (2023) implementation:
    https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/embodied/core/when.py#L26
    """

    def __init__(self, ratio: float, pretrain_steps: int = 0):
        if pretrain_steps < 0:
            raise ValueError(f"'pretrain_steps' must be non-negative, got {pretrain_steps}")
        if ratio < 0:
            raise ValueError(f"'ratio' must be non-negative, got {ratio}")
        self._pretrain_steps = pretrain_steps
        self._ratio = ratio
        self._prev = None

    def __call__(self, step: int) -> int:
        if self._ratio == 0:
            return 0
        if self._prev is None:
            self._prev = step
            repeats = 1
            if self._pretrain_steps > 0:
                if step < self._pretrain_steps:
                    warnings.warn(
                        "The number of pretrain steps is greater than the number of current steps. This could lead to "
                        f"a higher ratio than the one specified ({self._ratio}). Setting the 'pretrain_steps' equal to "
                        "the number of current steps."
                    )
                    self._pretrain_steps = step
                repeats = round(self._pretrain_steps * self._ratio)
            return repeats
        repeats = round((step - self._prev) * self._ratio)
        self._prev += repeats / self._ratio
        return repeats

    def state_dict(self) -> Dict[str, Any]:
        return {"_ratio": self._ratio, "_prev": self._prev, "_pretrain_steps": self._pretrain_steps}

    def load_state_dict(self, state_dict: Mapping[str, Any]):
        self._ratio = state_dict["_ratio"]
        self._prev = state_dict["_prev"]
        self._pretrain_steps = state_dict["_pretrain_steps"]
        return self


if __name__ == "__main__":
    num_envs = 1
    world_size = 1
    replay_ratio = 0.5
    per_rank_batch_size = 16
    per_rank_sequence_length = 64
    replayed_steps = world_size * per_rank_batch_size * per_rank_sequence_length
    train_steps = 0
    gradient_steps = 0
    total_policy_steps = 2**10
    r = Ratio(ratio=replay_ratio, pretrain_steps=0)
    policy_steps = num_envs * world_size
    printed = False
    for i in range(0, total_policy_steps, policy_steps):
        if i >= 128:
            per_rank_repeats = r(i / world_size)
            if per_rank_repeats > 0 and not printed:
                print(
                    f"Training the agent with {per_rank_repeats} repeats on every rank "
                    f"({per_rank_repeats * world_size} global repeats) at global iteration {i}"
                )
                printed = True
            gradient_steps += per_rank_repeats * world_size
    print("Replay ratio", replay_ratio)
    print("Hafner train ratio", replay_ratio * replayed_steps)
    print("Final ratio", gradient_steps / total_policy_steps)
17 changes: 7 additions & 10 deletions howto/configs.md
@@ -126,6 +126,8 @@ In the `algo` folder one can find all the configurations for every algorithm implemented

```yaml
# sheeprl/configs/algo/dreamer_v3.yaml
# Dreamer-V3 XL configuration
defaults:
  - default
  - /optim@world_model.optimizer: adam
@@ -139,10 +141,8 @@ lmbda: 0.95
horizon: 15
# Training recipe
train_every: 16
learning_starts: 65536
per_rank_pretrain_steps: 1
per_rank_gradient_steps: 1
replay_ratio: 1
learning_starts: 1024
per_rank_sequence_length: ???
# Encoder and decoder keys
@@ -159,6 +159,7 @@ dense_act: torch.nn.SiLU
cnn_act: torch.nn.SiLU
unimix: 0.01
hafner_initialization: True
decoupled_rssm: False
# World model
world_model:
@@ -241,10 +242,6 @@ actor:
  layer_norm: ${algo.layer_norm}
  dense_units: ${algo.dense_units}
  clip_gradients: 100.0
  expl_amount: 0.0
  expl_min: 0.0
  expl_decay: False
  max_step_expl_decay: 0
  # Distributed percentile model (used to scale the values)
  moments:
@@ -266,7 +263,7 @@ critic:
  mlp_layers: ${algo.mlp_layers}
  layer_norm: ${algo.layer_norm}
  dense_units: ${algo.dense_units}
  target_network_update_freq: 1
  per_rank_target_network_update_freq: 1
  tau: 0.02
  bins: 255
  clip_gradients: 100.0
@@ -410,7 +407,7 @@ buffer:
algo:
  learning_starts: 1024
  total_steps: 100000
  train_every: 1
  dense_units: 512
  mlp_layers: 2
  world_model:
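The removed `train_every`/`per_rank_gradient_steps` pair and the new `replay_ratio` schedule gradient steps differently. A rough comparison under simplifying assumptions (single process, single environment; the helper functions below are illustrative, not part of the commit):

```python
# Rough comparison of the two training recipes, assuming a single process and a single
# environment, and taking replay_ratio to mean per-rank gradient steps per policy step
# as implemented by the Ratio class shown above.


def gradient_steps_old(policy_steps: int, train_every: int = 16, per_rank_gradient_steps: int = 1) -> int:
    # Old recipe: one training round every `train_every` policy steps,
    # each round running `per_rank_gradient_steps` gradient steps.
    return (policy_steps // train_every) * per_rank_gradient_steps


def gradient_steps_new(policy_steps: int, replay_ratio: float = 1.0) -> int:
    # New recipe: `replay_ratio` gradient steps for every policy step.
    return round(policy_steps * replay_ratio)


if __name__ == "__main__":
    print(gradient_steps_old(1024))  # 64 gradient steps with train_every: 16
    print(gradient_steps_new(1024))  # 1024 gradient steps with replay_ratio: 1
```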
2 changes: 1 addition & 1 deletion notebooks/dreamer_v3_imagination.ipynb
@@ -230,7 +230,7 @@
" mask = {k: v for k, v in preprocessed_obs.items() if k.startswith(\"mask\")}\n",
" if len(mask) == 0:\n",
" mask = None\n",
" real_actions = actions = player.get_exploration_action(preprocessed_obs, mask)\n",
" real_actions = actions = player.get_actions(preprocessed_obs, mask)\n",
" actions = torch.cat(actions, -1).cpu().numpy()\n",
" if is_continuous:\n",
" real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()\n",
27 changes: 18 additions & 9 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -241,7 +241,7 @@ def __init__(
        encoder: nn.Module | _FabricModule,
        recurrent_model: nn.Module | _FabricModule,
        representation_model: nn.Module | _FabricModule,
        actor: nn.Module | _FabricModule,
        actor: Actor | _FabricModule,
        actions_dim: Sequence[int],
        num_envs: int,
        stochastic_size: int,
@@ -288,32 +288,38 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
            self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs])
            self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs])

    def get_exploration_action(self, obs: Tensor, mask: Optional[Dict[str, Tensor]] = None) -> Sequence[Tensor]:
    def get_exploration_actions(
        self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, step: int = 0
    ) -> Sequence[Tensor]:
        """Return the actions with a certain amount of noise for exploration.
        Args:
            obs (Tensor): the current observations.
            sample_actions (bool): whether or not to sample the actions.
                Default to True.
            mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed).
                Defaults to None.
            step (int): the step of the training, used for the exploration amount.
                Default to 0.
        Returns:
            The actions the agent has to perform (Sequence[Tensor]).
        """
        actions = self.get_greedy_action(obs, mask=mask)
        actions = self.get_actions(obs, sample_actions=sample_actions, mask=mask)
        expl_actions = None
        if self.actor.expl_amount > 0:
            expl_actions = self.actor.add_exploration_noise(actions, mask=mask)
        if self.actor._expl_amount > 0:
            expl_actions = self.actor.add_exploration_noise(actions, step=step, mask=mask)
        self.actions = torch.cat(expl_actions, dim=-1)
        return expl_actions or actions

    def get_greedy_action(
        self, obs: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None
    def get_actions(
        self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None
    ) -> Sequence[Tensor]:
        """Return the greedy actions.
        Args:
            obs (Tensor): the current observations.
            is_training (bool): whether it is training.
            sample_actions (bool): whether or not to sample the actions.
                Default to True.
            mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed).
                Defaults to None.
@@ -329,7 +335,7 @@ def get_greedy_action(
            self.representation_model(torch.cat((self.recurrent_state, embedded_obs), -1)),
            validate_args=self.validate_args,
        )
        actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), is_training, mask)
        actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask)
        self.actions = torch.cat(actions, -1)
        return actions

@@ -488,6 +494,9 @@ def build_agent(
        activation=eval(actor_cfg.dense_act),
        distribution_cfg=cfg.distribution,
        layer_norm=False,
        expl_amount=actor_cfg.expl_amount,
        expl_decay=actor_cfg.expl_decay,
        expl_min=actor_cfg.expl_min,
    )
    critic = MLP(
        input_dims=latent_state_size,
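The player API is renamed here: `get_greedy_action` becomes `get_actions` and `get_exploration_action` becomes `get_exploration_actions`, which now also receives the current `step` to drive the exploration schedule. A minimal, hypothetical wrapper showing how a caller might pick between the two (the `act` helper is not part of the commit; `player` is a `PlayerDV1` instance as built above):

```python
from typing import Dict, Optional, Sequence

import torch
from torch import Tensor


def act(
    player, obs: Tensor, policy_step: int, mask: Optional[Dict[str, Tensor]] = None, explore: bool = True
) -> Sequence[Tensor]:
    """Select actions with the renamed PlayerDV1 API (signatures as in the diff above)."""
    with torch.no_grad():
        if explore:
            # Adds exploration noise when the actor's exploration amount is positive;
            # `step` drives the exploration schedule (previously `get_exploration_action`).
            return player.get_exploration_actions(obs, sample_actions=True, mask=mask, step=policy_step)
        # Sampled actions without exploration noise (previously `get_greedy_action`).
        return player.get_actions(obs, sample_actions=True, mask=mask)
```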
81 changes: 37 additions & 44 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -27,7 +27,7 @@
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import polynomial_decay, save_configs
from sheeprl.utils.utils import Ratio, save_configs

# Decomment the following two lines if you cannot start an experiment with DMC environments
# os.environ["PYOPENGL_PLATFORM"] = ""
@@ -547,22 +547,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
    last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
    last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
    policy_steps_per_update = int(cfg.env.num_envs * world_size)
    updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0
    num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1
    learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0
    expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0
    max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size)
    if cfg.checkpoint.resume_from:
        cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
        actor.expl_amount = polynomial_decay(
            expl_decay_steps,
            initial=cfg.algo.actor.expl_amount,
            final=cfg.algo.actor.expl_min,
            max_decay_steps=max_step_expl_decay,
        )
    if not cfg.buffer.checkpoint:
        learning_starts += start_step

    # Create Ratio class
    ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)
    if cfg.checkpoint.resume_from:
        ratio.load_state_dict(state["ratio"])

    # Warning for log and checkpoint every
    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
        warnings.warn(
@@ -592,6 +588,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
    rb.add(step_data, validate_args=cfg.buffer.validate_args)
    player.init_states()

    cumulative_per_rank_gradient_steps = 0
    for update in range(start_step, num_updates + 1):
        policy_step += cfg.env.num_envs * world_size

@@ -624,7 +621,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
                if len(mask) == 0:
                    mask = None
                real_actions = actions = player.get_exploration_action(normalized_obs, mask)
                real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step)
                actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
                if is_continuous:
                    real_actions = torch.cat(real_actions, -1).cpu().numpy()
@@ -684,46 +681,37 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
            # Reset internal agent states
            player.init_states(reset_envs=dones_idxes)

        updates_before_training -= 1

        # Train the agent
        if update > learning_starts and updates_before_training <= 0:
            # Start training
            with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
                for i in range(cfg.algo.per_rank_gradient_steps):
        if update >= learning_starts:
            per_rank_gradient_steps = ratio(policy_step / world_size)
            if per_rank_gradient_steps > 0:
                with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
                    sample = rb.sample_tensors(
                        batch_size=cfg.algo.per_rank_batch_size,
                        sequence_length=cfg.algo.per_rank_sequence_length,
                        n_samples=1,
                        n_samples=per_rank_gradient_steps,
                        dtype=None,
                        device=device,
                        from_numpy=cfg.buffer.from_numpy,
                    )  # [N_samples, Seq_len, Batch_size, ...]
                    batch = {k: v[0].float() for k, v in sample.items()}
                    train(
                        fabric,
                        world_model,
                        actor,
                        critic,
                        world_optimizer,
                        actor_optimizer,
                        critic_optimizer,
                        batch,
                        aggregator,
                        cfg,
                    )
                train_step += world_size
            updates_before_training = cfg.algo.train_every // policy_steps_per_update
            if cfg.algo.actor.expl_decay:
                expl_decay_steps += 1
                actor.expl_amount = polynomial_decay(
                    expl_decay_steps,
                    initial=cfg.algo.actor.expl_amount,
                    final=cfg.algo.actor.expl_min,
                    max_decay_steps=max_step_expl_decay,
                )
            if aggregator:
                aggregator.update("Params/exploration_amount", actor.expl_amount)
                    for i in range(per_rank_gradient_steps):
                        batch = {k: v[i].float() for k, v in sample.items()}
                        train(
                            fabric,
                            world_model,
                            actor,
                            critic,
                            world_optimizer,
                            actor_optimizer,
                            critic_optimizer,
                            batch,
                            aggregator,
                            cfg,
                        )
                        cumulative_per_rank_gradient_steps += 1
                    train_step += world_size
            if aggregator:
                aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step))

        # Log metrics
        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
@@ -733,6 +721,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                fabric.log_dict(metrics_dict, policy_step)
                aggregator.reset()

            # Log replay ratio
            fabric.log(
                "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step
            )

            # Sync distributed timers
            if not timer.disabled:
                timer_metrics = timer.compute()
@@ -767,7 +760,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                "world_optimizer": world_optimizer.state_dict(),
                "actor_optimizer": actor_optimizer.state_dict(),
                "critic_optimizer": critic_optimizer.state_dict(),
                "expl_decay_steps": expl_decay_steps,
                "ratio": ratio.state_dict(),
                "update": update * world_size,
                "batch_size": cfg.algo.per_rank_batch_size * world_size,
                "last_log": last_log,
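Taken together, the new training trigger replaces the `train_every`/`updates_before_training` bookkeeping with a single `Ratio` call per iteration. A stripped-down sketch of that pattern (no buffer, models, or Fabric; all numbers are illustrative):

```python
# Stripped-down sketch of the new training trigger, mirroring the loop above without
# the replay buffer, models or Fabric. The Ratio import matches the one added to
# dreamer_v1.py (use examples/ratio.py when running outside the repository).
from sheeprl.utils.utils import Ratio

num_envs = 1
world_size = 1
learning_starts = 128  # in update units, as after the division in the diff above
total_updates = 1024
ratio = Ratio(ratio=0.5, pretrain_steps=0)

policy_step = 0
cumulative_per_rank_gradient_steps = 0
for update in range(1, total_updates + 1):
    policy_step += num_envs * world_size  # environment interaction happens here
    if update >= learning_starts:
        per_rank_gradient_steps = ratio(policy_step / world_size)
        for _ in range(per_rank_gradient_steps):
            # train(...) would run here on one sampled batch
            cumulative_per_rank_gradient_steps += 1

# The quantity logged as "Params/replay_ratio" in the diff above:
print("replay ratio:", cumulative_per_rank_gradient_steps * world_size / policy_step)
```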
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/utils.py
@@ -31,10 +31,10 @@
"State/post_entropy",
"State/prior_entropy",
"State/kl",
"Params/exploration_amount",
"Grads/world_model",
"Grads/actor",
"Grads/critic",
"Params/exploration_amount",
}
MODELS_TO_REGISTER = {"world_model", "actor", "critic"}
