diff --git a/examples/ratio.py b/examples/ratio.py new file mode 100644 index 00000000..03712916 --- /dev/null +++ b/examples/ratio.py @@ -0,0 +1,74 @@ +import warnings +from typing import Any, Dict, Mapping + + +class Ratio: + """Directly taken from Hafner et al. (2023) implementation: + https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/embodied/core/when.py#L26 + """ + + def __init__(self, ratio: float, pretrain_steps: int = 0): + if pretrain_steps < 0: + raise ValueError(f"'pretrain_steps' must be non-negative, got {pretrain_steps}") + if ratio < 0: + raise ValueError(f"'ratio' must be non-negative, got {ratio}") + self._pretrain_steps = pretrain_steps + self._ratio = ratio + self._prev = None + + def __call__(self, step: int) -> int: + if self._ratio == 0: + return 0 + if self._prev is None: + self._prev = step + repeats = 1 + if self._pretrain_steps > 0: + if step < self._pretrain_steps: + warnings.warn( + "The number of pretrain steps is greater than the number of current steps. This could lead to " + f"a higher ratio than the one specified ({self._ratio}). Setting the 'pretrain_steps' equal to " + "the number of current steps." + ) + self._pretrain_steps = step + repeats = round(self._pretrain_steps * self._ratio) + return repeats + repeats = round((step - self._prev) * self._ratio) + self._prev += repeats / self._ratio + return repeats + + def state_dict(self) -> Dict[str, Any]: + return {"_ratio": self._ratio, "_prev": self._prev, "_pretrain_steps": self._pretrain_steps} + + def load_state_dict(self, state_dict: Mapping[str, Any]): + self._ratio = state_dict["_ratio"] + self._prev = state_dict["_prev"] + self._pretrain_steps = state_dict["_pretrain_steps"] + return self + + +if __name__ == "__main__": + num_envs = 1 + world_size = 1 + replay_ratio = 0.5 + per_rank_batch_size = 16 + per_rank_sequence_length = 64 + replayed_steps = world_size * per_rank_batch_size * per_rank_sequence_length + train_steps = 0 + gradient_steps = 0 + total_policy_steps = 2**10 + r = Ratio(ratio=replay_ratio, pretrain_steps=0) + policy_steps = num_envs * world_size + printed = False + for i in range(0, total_policy_steps, policy_steps): + if i >= 128: + per_rank_repeats = r(i / world_size) + if per_rank_repeats > 0 and not printed: + print( + f"Training the agent with {per_rank_repeats} repeats on every rank " + f"({per_rank_repeats * world_size} global repeats) at global iteration {i}" + ) + printed = True + gradient_steps += per_rank_repeats * world_size + print("Replay ratio", replay_ratio) + print("Hafner train ratio", replay_ratio * replayed_steps) + print("Final ratio", gradient_steps / total_policy_steps) diff --git a/howto/configs.md b/howto/configs.md index ea49072b..6a6d096c 100644 --- a/howto/configs.md +++ b/howto/configs.md @@ -126,6 +126,8 @@ In the `algo` folder one can find all the configurations for every algorithm imp ```yaml # sheeprl/configs/algo/dreamer_v3.yaml +# Dreamer-V3 XL configuration + defaults: - default - /optim@world_model.optimizer: adam @@ -139,10 +141,8 @@ lmbda: 0.95 horizon: 15 # Training recipe -train_every: 16 -learning_starts: 65536 -per_rank_pretrain_steps: 1 -per_rank_gradient_steps: 1 +replay_ratio: 1 +learning_starts: 1024 per_rank_sequence_length: ??? 
# Encoder and decoder keys @@ -159,6 +159,7 @@ dense_act: torch.nn.SiLU cnn_act: torch.nn.SiLU unimix: 0.01 hafner_initialization: True +decoupled_rssm: False # World model world_model: @@ -241,10 +242,6 @@ actor: layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} clip_gradients: 100.0 - expl_amount: 0.0 - expl_min: 0.0 - expl_decay: False - max_step_expl_decay: 0 # Disttributed percentile model (used to scale the values) moments: @@ -266,7 +263,7 @@ critic: mlp_layers: ${algo.mlp_layers} layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} - target_network_update_freq: 1 + per_rank_target_network_update_freq: 1 tau: 0.02 bins: 255 clip_gradients: 100.0 @@ -410,7 +407,7 @@ buffer: algo: learning_starts: 1024 total_steps: 100000 - train_every: 1 + dense_units: 512 mlp_layers: 2 world_model: diff --git a/notebooks/dreamer_v3_imagination.ipynb b/notebooks/dreamer_v3_imagination.ipynb index 1a7f2a44..c58531f2 100644 --- a/notebooks/dreamer_v3_imagination.ipynb +++ b/notebooks/dreamer_v3_imagination.ipynb @@ -230,7 +230,7 @@ " mask = {k: v for k, v in preprocessed_obs.items() if k.startswith(\"mask\")}\n", " if len(mask) == 0:\n", " mask = None\n", - " real_actions = actions = player.get_exploration_action(preprocessed_obs, mask)\n", + " real_actions = actions = player.get_actions(preprocessed_obs, mask)\n", " actions = torch.cat(actions, -1).cpu().numpy()\n", " if is_continuous:\n", " real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()\n", diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index d9b69d54..e40aa8f6 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -241,7 +241,7 @@ def __init__( encoder: nn.Module | _FabricModule, recurrent_model: nn.Module | _FabricModule, representation_model: nn.Module | _FabricModule, - actor: nn.Module | _FabricModule, + actor: Actor | _FabricModule, actions_dim: Sequence[int], num_envs: int, stochastic_size: int, @@ -288,32 +288,38 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs]) self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs]) - def get_exploration_action(self, obs: Tensor, mask: Optional[Dict[str, Tensor]] = None) -> Sequence[Tensor]: + def get_exploration_actions( + self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, step: int = 0 + ) -> Sequence[Tensor]: """Return the actions with a certain amount of noise for exploration. Args: obs (Tensor): the current observations. + sample_actions (bool): whether or not to sample the actions. + Default to True. mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed). Defaults to None. + step (int): the step of the training, used for the exploration amount. + Default to 0. Returns: The actions the agent has to perform (Sequence[Tensor]). 
""" - actions = self.get_greedy_action(obs, mask=mask) + actions = self.get_actions(obs, sample_actions=sample_actions, mask=mask) expl_actions = None - if self.actor.expl_amount > 0: - expl_actions = self.actor.add_exploration_noise(actions, mask=mask) + if self.actor._expl_amount > 0: + expl_actions = self.actor.add_exploration_noise(actions, step=step, mask=mask) self.actions = torch.cat(expl_actions, dim=-1) return expl_actions or actions - def get_greedy_action( - self, obs: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None + def get_actions( + self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Sequence[Tensor]: """Return the greedy actions. Args: obs (Tensor): the current observations. - is_training (bool): whether it is training. + sample_actions (bool): whether or not to sample the actions. Default to True. mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed). Defaults to None. @@ -329,7 +335,7 @@ def get_greedy_action( self.representation_model(torch.cat((self.recurrent_state, embedded_obs), -1)), validate_args=self.validate_args, ) - actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), is_training, mask) + actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask) self.actions = torch.cat(actions, -1) return actions @@ -488,6 +494,9 @@ def build_agent( activation=eval(actor_cfg.dense_act), distribution_cfg=cfg.distribution, layer_norm=False, + expl_amount=actor_cfg.expl_amount, + expl_decay=actor_cfg.expl_decay, + expl_min=actor_cfg.expl_min, ) critic = MLP( input_dims=latent_state_size, diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 0326a142..29d52fb9 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -27,7 +27,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -547,22 +547,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if 
cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -592,6 +588,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -624,7 +621,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -684,46 +681,37 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Reset internal agent states player.init_states(reset_envs=dones_idxes) - updates_before_training -= 1 - # Train the agent - if update > learning_starts and updates_before_training <= 0: - # Start training - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(cfg.algo.per_rank_gradient_steps): + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): sample = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=1, + n_samples=per_rank_gradient_steps, dtype=None, device=device, from_numpy=cfg.buffer.from_numpy, ) # [N_samples, Seq_len, Batch_size, ...] 
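For context on the scheduling above, here is a minimal, illustrative sketch (not taken from the diff) of how the `Ratio` scheduler imported from `sheeprl.utils.utils` converts collected policy steps into per-rank gradient steps; it assumes a single rank and a single environment, so `policy_step / world_size` reduces to the raw step counter.

```python
# Minimal sketch: with replay_ratio = 0.5 the scheduler asks for roughly one
# gradient step every two policy steps, so that over training
#   cumulative_gradient_steps ≈ replay_ratio * policy_steps.
from sheeprl.utils.utils import Ratio  # as imported by the updated scripts

ratio = Ratio(ratio=0.5, pretrain_steps=0)
cumulative_gradient_steps = 0
for policy_step in range(1024):  # single rank, single env
    cumulative_gradient_steps += ratio(policy_step)
print(cumulative_gradient_steps / 1024)  # ≈ 0.5, the configured replay ratio
```

Multiplying the replay ratio by `per_rank_batch_size * per_rank_sequence_length` gives the Hafner-style "train ratio" printed at the end of `examples/ratio.py`.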
- batch = {k: v[0].float() for k, v in sample.items()} - train( - fabric, - world_model, - actor, - critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - batch, - aggregator, - cfg, - ) - train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - if aggregator: - aggregator.update("Params/exploration_amount", actor.expl_amount) + for i in range(per_rank_gradient_steps): + batch = {k: v[i].float() for k, v in sample.items()} + train( + fabric, + world_model, + actor, + critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + batch, + aggregator, + cfg, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size + if aggregator: + aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step)) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -733,6 +721,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -767,7 +760,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "last_log": last_log, diff --git a/sheeprl/algos/dreamer_v1/utils.py b/sheeprl/algos/dreamer_v1/utils.py index cbf372d0..1c91019d 100644 --- a/sheeprl/algos/dreamer_v1/utils.py +++ b/sheeprl/algos/dreamer_v1/utils.py @@ -31,10 +31,10 @@ "State/post_entropy", "State/prior_entropy", "State/kl", - "Params/exploration_amount", "Grads/world_model", "Grads/actor", "Grads/critic", + "Params/exploration_amount", } MODELS_TO_REGISTER = {"world_model", "actor", "critic"} diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index 6f546f43..63f766f3 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -435,6 +435,10 @@ class Actor(nn.Module): Default to False. expl_amount (float): the exploration amount to use during training. Default to 0.0. + expl_decay (float): the exploration decay to use during training. + Default to 0.0. + expl_min (float): the exploration amount minimum to use during training. + Default to 0.0. 
""" def __init__( @@ -450,6 +454,8 @@ def __init__( mlp_layers: int = 4, layer_norm: bool = False, expl_amount: float = 0.0, + expl_decay: float = 0.0, + expl_min: float = 0.0, ) -> None: super().__init__() self.distribution_cfg = distribution_cfg @@ -485,17 +491,17 @@ def __init__( self.min_std = min_std self.distribution_cfg = distribution_cfg self._expl_amount = expl_amount + self._expl_decay = expl_decay + self._expl_min = expl_min - @property - def expl_amount(self) -> float: - return self._expl_amount - - @expl_amount.setter - def expl_amount(self, amount: float): - self._expl_amount = amount + def _get_expl_amount(self, step: int) -> Tensor: + amount = self._expl_amount + if self._expl_decay: + amount *= 0.5 ** float(step) / self._expl_decay + return max(amount, self._expl_min) def forward( - self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -503,7 +509,7 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). - is_training (bool): whether it is in the training phase. + sample_actions (bool): whether or not to sample the actions. Default to True. mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). Default to None. @@ -534,7 +540,7 @@ def forward( std = 2 * torch.sigmoid((std + self.init_std) / 2) + self.min_std dist = TruncatedNormal(torch.tanh(mean), std, -1, 1, validate_args=self.distribution_cfg.validate_args) actions_dist = Independent(dist, 1, validate_args=self.distribution_cfg.validate_args) - if is_training: + if sample_actions: actions = actions_dist.rsample() else: sample = actions_dist.sample((100,)) @@ -551,19 +557,20 @@ def forward( logits=logits, validate_args=self.distribution_cfg.validate_args ) ) - if is_training: + if sample_actions: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) return tuple(actions), tuple(actions_dist) def add_exploration_noise( - self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None + self, actions: Sequence[Tensor], step: int = 0, mask: Optional[Dict[str, Tensor]] = None ) -> Sequence[Tensor]: + expl_amount = self._get_expl_amount(step) if self.is_continuous: actions = torch.cat(actions, -1) - if self._expl_amount > 0.0: - actions = torch.clip(Normal(actions, self._expl_amount).sample(), -1, 1) + if expl_amount > 0.0: + actions = torch.clip(Normal(actions, expl_amount).sample(), -1, 1) expl_actions = [actions] else: expl_actions = [] @@ -574,7 +581,7 @@ def add_exploration_noise( .to(act.device) ) expl_actions.append( - torch.where(torch.rand(act.shape[:1], device=act.device) < self._expl_amount, sample, act) + torch.where(torch.rand(act.shape[:1], device=act.device) < expl_amount, sample, act) ) return tuple(expl_actions) @@ -593,6 +600,8 @@ def __init__( mlp_layers: int = 4, layer_norm: bool = False, expl_amount: float = 0.0, + expl_decay: float = 0.0, + expl_min: float = 0.0, ) -> None: super().__init__( latent_state_size=latent_state_size, @@ -606,10 +615,12 @@ def __init__( mlp_layers=mlp_layers, layer_norm=layer_norm, expl_amount=expl_amount, + expl_decay=expl_decay, + expl_min=expl_min, ) def forward( - self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] 
= None + self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -617,7 +628,7 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). - is_training (bool): whether it is in the training phase. + sample_actions (bool): whether or not to sample the actions. Default to True. mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). Default to None. @@ -657,7 +668,7 @@ def forward( logits=logits, validate_args=self.distribution_cfg.validate_args ) ) - if is_training: + if sample_actions: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) @@ -666,7 +677,7 @@ def forward( return tuple(actions), tuple(actions_dist) def add_exploration_noise( - self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None + self, actions: Sequence[Tensor], step: int = 0, mask: Optional[Dict[str, Tensor]] = None ) -> Sequence[Tensor]: expl_actions = [] functional_action = actions[0].argmax(dim=-1) @@ -696,7 +707,7 @@ def add_exploration_noise( sample = ( OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False).sample().to(act.device) ) - expl_amount = self.expl_amount + expl_amount = self._get_expl_amount(step) # If the action[0] was changed, and now it is critical, then we force to change also the other 2 actions # to satisfy the constraints of the environment if ( @@ -816,30 +827,10 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.recurrent_state[:, reset_envs] = torch.zeros_like(self.recurrent_state[:, reset_envs]) self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs]) - def get_exploration_action(self, obs: Dict[str, Tensor], mask: Optional[Dict[str, Tensor]] = None) -> Tensor: - """ - Return the actions with a certain amount of noise for exploration. - - Args: - obs (Dict[str, Tensor]): the current observations. - is_continuous (bool): whether or not the actions are continuous. - mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). - Default to None. - - Returns: - The actions the agent has to perform. - """ - actions = self.get_greedy_action(obs, mask=mask) - expl_actions = None - if self.actor.expl_amount > 0: - expl_actions = self.actor.add_exploration_noise(actions, mask=mask) - self.actions = torch.cat(expl_actions, dim=-1) - return expl_actions or actions - - def get_greedy_action( + def get_actions( self, obs: Dict[str, Tensor], - is_training: bool = True, + sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, ) -> Sequence[Tensor]: """ @@ -847,7 +838,7 @@ def get_greedy_action( Args: obs (Dict[str, Tensor]): the current observations. - is_training (bool): whether it is training. + sample_actions (bool): whether or not to sample the actions. Default to True. mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). Default to None. 
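As a reading aid for the new `_get_expl_amount` above, here is a small, self-contained sketch of the half-life schedule it appears intended to implement. The parenthesization below, `0.5 ** (step / decay)`, is my assumption about the intent; the expression in the diff, `0.5 ** float(step) / self._expl_decay`, is parsed by Python as `(0.5 ** step) / decay` because `**` binds tighter than `/`.

```python
# Hedged sketch of an exponential (half-life) exploration schedule, assuming
# `decay` is the number of steps over which the exploration amount halves.
def expl_amount(step: int, initial: float = 0.0, decay: float = 0.0, minimum: float = 0.0) -> float:
    amount = initial
    if decay:
        amount *= 0.5 ** (step / decay)  # explicit parentheses for the half-life form
    return max(amount, minimum)

# Example: initial=0.3, decay=10_000 -> 0.15 after 10k steps, clamped at `minimum`.
```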
@@ -866,7 +857,7 @@ def get_greedy_action( self.stochastic_state = stochastic_state.view( *stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size ) - actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), is_training, mask) + actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask) self.actions = torch.cat(actions, -1) return actions diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 613ce927..159e4869 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -32,7 +32,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -579,22 +579,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -625,7 +621,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() - per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -658,7 +654,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_actions(normalized_obs, mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -720,54 +716,43 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Reset internal agent states player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - n_samples = ( - cfg.algo.per_rank_pretrain_steps if update 
== learning_starts else cfg.algo.per_rank_gradient_steps - ) - local_data = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=n_samples, - dtype=None, - device=fabric.device, - from_numpy=cfg.buffer.from_numpy, - ) - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()): - tcp.data.copy_(cp.data) - batch = {k: v[i].float() for k, v in local_data.items()} - train( - fabric, - world_model, - actor, - critic, - target_critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - batch, - aggregator, - cfg, - actions_dim, - ) - per_rank_gradient_steps += 1 - train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + local_data = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=fabric.device, + from_numpy=cfg.buffer.from_numpy, ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount", actor.expl_amount) + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for i in range(per_rank_gradient_steps): + if ( + cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq + == 0 + ): + for cp, tcp in zip(critic.module.parameters(), target_critic.module.parameters()): + tcp.data.copy_(cp.data) + batch = {k: v[i].float() for k, v in local_data.items()} + train( + fabric, + world_model, + actor, + critic, + target_critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + batch, + aggregator, + cfg, + actions_dim, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -777,6 +762,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -812,7 +802,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "last_log": last_log, diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py index 799674fc..d64134e9 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ b/sheeprl/algos/dreamer_v2/utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union 
+from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence import gymnasium as gym import torch @@ -34,7 +34,6 @@ "State/post_entropy", "State/prior_entropy", "State/kl", - "Params/exploration_amount", "Grads/world_model", "Grads/actor", "Grads/critic", @@ -107,7 +106,7 @@ def compute_lambda_values( @torch.no_grad() def test( - player: Union["PlayerDV2", "PlayerDV1"], + player: "PlayerDV2" | "PlayerDV1", fabric: Fabric, cfg: Dict[str, Any], log_dir: str, @@ -143,7 +142,7 @@ def test( preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 elif k in cfg.algo.mlp_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) - real_actions = player.get_greedy_action( + real_actions = player.get_actions( preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index f29546dc..b7fa0065 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -18,11 +18,7 @@ from sheeprl.algos.dreamer_v2.utils import compute_stochastic_state from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights from sheeprl.models.models import CNN, MLP, DeCNN, LayerNormGRUCell, MultiDecoder, MultiEncoder -from sheeprl.utils.distribution import ( - OneHotCategoricalStraightThroughValidateArgs, - OneHotCategoricalValidateArgs, - TruncatedNormal, -) +from sheeprl.utils.distribution import OneHotCategoricalStraightThroughValidateArgs, TruncatedNormal from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.model import LayerNormChannelLast, ModuleType, cnn_forward from sheeprl.utils.utils import symlog @@ -641,29 +637,10 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.recurrent_state[:, reset_envs], sample_state=False )[1].reshape(1, len(reset_envs), -1) - def get_exploration_action(self, obs: Dict[str, Tensor], mask: Optional[Dict[str, Tensor]] = None) -> Tensor: - """ - Return the actions with a certain amount of noise for exploration. - - Args: - obs (Dict[str, Tensor]): the current observations. - mask (Dict[str, Tensor], optional): the mask of the actions. - Default to None. - - Returns: - The actions the agent has to perform. - """ - actions = self.get_greedy_action(obs, mask=mask) - expl_actions = None - if self.actor.expl_amount > 0: - expl_actions = self.actor.add_exploration_noise(actions, mask=mask) - self.actions = torch.cat(expl_actions, dim=-1) - return expl_actions or actions - - def get_greedy_action( + def get_actions( self, obs: Dict[str, Tensor], - is_training: bool = True, + sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, ) -> Sequence[Tensor]: """ @@ -671,7 +648,7 @@ def get_greedy_action( Args: obs (Dict[str, Tensor]): the current observations. - is_training (bool): whether it is training. + sample_actions (bool): whether or not to sample the actions. Default to True. 
Returns: @@ -688,7 +665,7 @@ def get_greedy_action( self.stochastic_state = self.stochastic_state.view( *self.stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size ) - actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), is_training, mask) + actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask) self.actions = torch.cat(actions, -1) return actions @@ -720,8 +697,6 @@ class Actor(nn.Module): then `p = (1 - self.unimix) * p + self.unimix * unif`, where `unif = `1 / self.discrete`. Defaults to 0.01. - expl_amount (float): the exploration amount to use during training. - Default to 0.0. """ def __init__( @@ -737,7 +712,6 @@ def __init__( mlp_layers: int = 5, layer_norm: bool = True, unimix: float = 0.01, - expl_amount: float = 0.0, ) -> None: super().__init__() self.distribution_cfg = distribution_cfg @@ -775,18 +749,9 @@ def __init__( self.init_std = torch.tensor(init_std) self.min_std = min_std self._unimix = unimix - self._expl_amount = expl_amount - - @property - def expl_amount(self) -> float: - return self._expl_amount - - @expl_amount.setter - def expl_amount(self, amount: float): - self._expl_amount = amount def forward( - self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -794,6 +759,10 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). + sample_actions (bool): whether or not to sample the actions. + Default to True. + mask (Dict[str, Tensor], optional): the mask to use on the actions. + Default to None. Returns: The tensor of the actions taken by the agent with shape (batch_size, *, num_actions). 
@@ -821,7 +790,7 @@ def forward( std = 2 * torch.sigmoid((std + self.init_std) / 2) + self.min_std dist = TruncatedNormal(torch.tanh(mean), std, -1, 1, validate_args=self.distribution_cfg.validate_args) actions_dist = Independent(dist, 1, validate_args=self.distribution_cfg.validate_args) - if is_training: + if sample_actions: actions = actions_dist.rsample() else: sample = actions_dist.sample((100,)) @@ -838,7 +807,7 @@ def forward( logits=self._uniform_mix(logits), validate_args=self.distribution_cfg.validate_args ) ) - if is_training: + if sample_actions: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) @@ -852,27 +821,6 @@ def _uniform_mix(self, logits: Tensor) -> Tensor: logits = probs_to_logits(probs) return logits - def add_exploration_noise( - self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None - ) -> Sequence[Tensor]: - if self.is_continuous: - actions = torch.cat(actions, -1) - if self._expl_amount > 0.0: - actions = torch.clip(Normal(actions, self._expl_amount).sample(), -1, 1) - expl_actions = [actions] - else: - expl_actions = [] - for act in actions: - sample = ( - OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False) - .sample() - .to(act.device) - ) - expl_actions.append( - torch.where(torch.rand(act.shape[:1], device=act.device) < self._expl_amount, sample, act) - ) - return tuple(expl_actions) - class MinedojoActor(Actor): def __init__( @@ -888,7 +836,6 @@ def __init__( mlp_layers: int = 5, layer_norm: bool = True, unimix: float = 0.01, - expl_amount: float = 0.0, ) -> None: super().__init__( latent_state_size=latent_state_size, @@ -902,11 +849,10 @@ def __init__( mlp_layers=mlp_layers, layer_norm=layer_norm, unimix=unimix, - expl_amount=expl_amount, ) def forward( - self, state: Tensor, is_training: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -914,6 +860,10 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). + sample_actions (bool): whether or not to sample the actions. + Default to True. + mask (Dict[str, Tensor], optional): the mask to apply to the actions. + Default to None. Returns: The tensor of the actions taken by the agent with shape (batch_size, *, num_actions). 
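For reference, a standalone sketch of the `unimix` mixing described in the Actor docstring above (`p = (1 - unimix) * p + unimix * unif`). It is illustrative only and assumes nothing beyond the `logits_to_probs`/`probs_to_logits` helpers from `torch.distributions.utils`.

```python
import torch
from torch.distributions.utils import logits_to_probs, probs_to_logits


def uniform_mix(logits: torch.Tensor, unimix: float = 0.01) -> torch.Tensor:
    # Mix the categorical probabilities with a uniform distribution, then
    # convert back to logits: p = (1 - unimix) * p + unimix * (1 / num_classes).
    probs = logits_to_probs(logits)
    uniform = torch.ones_like(probs) / probs.shape[-1]
    probs = (1 - unimix) * probs + unimix * uniform
    return probs_to_logits(probs)
```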
@@ -950,7 +900,7 @@ def forward( logits=logits, validate_args=self.distribution_cfg.validate_args ) ) - if is_training: + if sample_actions: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) @@ -958,51 +908,6 @@ def forward( functional_action = actions[0].argmax(dim=-1) # [T, B] return tuple(actions), tuple(actions_dist) - def add_exploration_noise( - self, actions: Sequence[Tensor], mask: Optional[Dict[str, Tensor]] = None - ) -> Sequence[Tensor]: - expl_actions = [] - functional_action = actions[0].argmax(dim=-1) - for i, act in enumerate(actions): - logits = torch.zeros_like(act) - # Exploratory action must respect the constraints of the environment - if mask is not None: - if i == 0: - logits[torch.logical_not(mask["mask_action_type"].expand_as(logits))] = -torch.inf - elif i == 1: - mask["mask_craft_smelt"] = mask["mask_craft_smelt"].expand_as(logits) - for t in range(functional_action.shape[0]): - for b in range(functional_action.shape[1]): - sampled_action = functional_action[t, b].item() - if sampled_action == 15: # Craft action - logits[t, b][torch.logical_not(mask["mask_craft_smelt"][t, b])] = -torch.inf - elif i == 2: - mask["mask_destroy"][t, b] = mask["mask_destroy"].expand_as(logits) - mask["mask_equip_place"] = mask["mask_equip_place"].expand_as(logits) - for t in range(functional_action.shape[0]): - for b in range(functional_action.shape[1]): - sampled_action = functional_action[t, b].item() - if sampled_action in {16, 17}: # Equip/Place action - logits[t, b][torch.logical_not(mask["mask_equip_place"][t, b])] = -torch.inf - elif sampled_action == 18: # Destroy action - logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf - sample = ( - OneHotCategoricalValidateArgs(logits=torch.zeros_like(act), validate_args=False).sample().to(act.device) - ) - expl_amount = self.expl_amount - # If the action[0] was changed, and now it is critical, then we force to change also the other 2 actions - # to satisfy the constraints of the environment - if ( - i in {1, 2} - and actions[0].argmax() != expl_actions[0].argmax() - and expl_actions[0].argmax().item() in {15, 16, 17, 18} - ): - expl_amount = 2 - expl_actions.append(torch.where(torch.rand(act.shape[:1], device=self.device) < expl_amount, sample, act)) - if mask is not None and i == 0: - functional_action = expl_actions[0].argmax(dim=-1) - return tuple(expl_actions) - def build_agent( fabric: Fabric, diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 7984e85c..432572e7 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -39,7 +39,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -524,7 +524,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables train_step = 0 @@ -541,21 +540,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = 
state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -582,7 +578,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["is_first"] = np.ones_like(step_data["dones"]) player.init_states() - per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -614,7 +610,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) + real_actions = actions = player.get_actions(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() @@ -688,56 +684,46 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - local_data = rb.sample_tensors( - cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ), - dtype=None, - device=fabric.device, - from_numpy=cfg.buffer.from_numpy, - ) - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau - for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): - tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) - batch = {k: v[i].float() for k, v in local_data.items()} - train( - fabric, - world_model, - actor, - critic, - target_critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - batch, - aggregator, - cfg, - is_continuous, - actions_dim, - moments, - ) - per_rank_gradient_steps += 1 - train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - 
final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + local_data = rb.sample_tensors( + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=fabric.device, + from_numpy=cfg.buffer.from_numpy, ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount", actor.expl_amount) + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for i in range(per_rank_gradient_steps): + if ( + cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq + == 0 + ): + tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau + for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): + tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + batch = {k: v[i].float() for k, v in local_data.items()} + train( + fabric, + world_model, + actor, + critic, + target_critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + batch, + aggregator, + cfg, + is_continuous, + actions_dim, + moments, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -747,6 +733,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -782,8 +773,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, "moments": moments.state_dict(), + "ratio": ratio.state_dict(), "update": update * fabric.world_size, "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index d18d1fb0..efd10ac3 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -30,7 +30,6 @@ "State/kl", "State/post_entropy", "State/prior_entropy", - "Params/exploration_amount", "Grads/world_model", "Grads/actor", "Grads/critic", @@ -116,7 +115,7 @@ def test( preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 elif k in cfg.algo.mlp_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) - real_actions = player.get_greedy_action( + real_actions = player.get_actions( preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 2d7f8c1a..186a905f 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -26,7 +26,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import save_configs +from sheeprl.utils.utils import Ratio, save_configs def train( @@ -38,17 +38,22 @@ def train( rb: ReplayBuffer, aggregator: MetricAggregator | None, cfg: Dict[str, Any], + per_rank_gradient_steps: 
int, ): # Sample a minibatch in a distributed way: Line 5 - Algorithm 2 # We sample one time to reduce the communications between processes sample = rb.sample_tensors( - cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, + per_rank_gradient_steps * cfg.algo.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs, from_numpy=cfg.buffer.from_numpy, ) - critic_data = fabric.all_gather(sample) - flatten_dim = 3 if fabric.world_size > 1 else 2 - critic_data = {k: v.view(-1, *v.shape[flatten_dim:]) for k, v in critic_data.items()} + critic_data: Dict[str, torch.Tensor] = fabric.all_gather(sample) # [World, G*B] + for k, v in critic_data.items(): + critic_data[k] = v.float() # [G*B*World] + if fabric.world_size > 1: + critic_data[k] = critic_data[k].flatten(start_dim=0, end_dim=2) + else: + critic_data[k] = critic_data[k].flatten(start_dim=0, end_dim=1) critic_idxes = range(len(critic_data[next(iter(critic_data.keys()))])) if fabric.world_size > 1: dist_sampler: DistributedSampler = DistributedSampler( @@ -68,7 +73,12 @@ def train( # Sample a different minibatch in a distributed way to update actor and alpha parameter sample = rb.sample_tensors(cfg.algo.per_rank_batch_size, from_numpy=cfg.buffer.from_numpy) actor_data = fabric.all_gather(sample) - actor_data = {k: v.view(-1, *v.shape[flatten_dim:]) for k, v in actor_data.items()} + for k, v in actor_data.items(): + actor_data[k] = v.float() # [G*B*World] + if fabric.world_size > 1: + actor_data[k] = actor_data[k].flatten(start_dim=0, end_dim=2) + else: + actor_data[k] = actor_data[k].flatten(start_dim=0, end_dim=1) if fabric.world_size > 1: actor_sampler: DistributedSampler = DistributedSampler( range(len(actor_data[next(iter(actor_data.keys()))])), @@ -259,6 +269,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -280,16 +295,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): o = envs.reset(seed=cfg.seed)[0] obs = np.concatenate([o[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype(np.float32) + per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * fabric.world_size # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - with torch.no_grad(): - # Sample an action given the observation received by the environment - actions, _ = actor(torch.from_numpy(obs).to(device)) - actions = actions.cpu().numpy() + if update <= learning_starts: + actions = envs.action_space.sample() + else: + with torch.inference_mode(): + # Sample an action given the observation received by the environment + actions, _ = actor(torch.from_numpy(obs).to(device)) + actions = actions.cpu().numpy() next_obs, rewards, dones, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) dones = np.logical_or(dones, truncated) @@ -328,9 +348,22 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = next_obs # Train the agent - if update > learning_starts: - train(fabric, agent, 
actor_optimizer, qf_optimizer, alpha_optimizer, rb, aggregator, cfg) - train_step += world_size + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + train( + fabric, + agent, + actor_optimizer, + qf_optimizer, + alpha_optimizer, + rb, + aggregator, + cfg, + per_rank_gradient_steps, + ) + train_step += world_size + cumulative_per_rank_gradient_steps += per_rank_gradient_steps # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -340,6 +373,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -372,6 +410,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "qf_optimizer": qf_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), + "ratio": ratio.state_dict(), "update": update * fabric.world_size, "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, diff --git a/sheeprl/algos/p2e_dv1/agent.py b/sheeprl/algos/p2e_dv1/agent.py index 7d13b21d..32a6269d 100644 --- a/sheeprl/algos/p2e_dv1/agent.py +++ b/sheeprl/algos/p2e_dv1/agent.py @@ -95,6 +95,9 @@ def build_agent( activation=eval(actor_cfg.dense_act), distribution_cfg=cfg.distribution, layer_norm=False, + expl_amount=actor_cfg.expl_amount, + expl_decay=actor_cfg.expl_decay, + expl_min=actor_cfg.expl_min, ) critic_task = MLP( input_dims=latent_state_size, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 9ad04faa..6e9651d1 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -29,7 +29,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -557,7 +557,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") step_data = {} - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables train_step = 0 @@ -574,27 +573,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - 
initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -624,6 +614,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -656,7 +647,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -716,59 +707,47 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Reset internal agent states player.init_states(reset_envs=dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - # Start training - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(cfg.algo.per_rank_gradient_steps): + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): sample = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=1, + n_samples=per_rank_gradient_steps, dtype=None, device=device, from_numpy=cfg.buffer.from_numpy, ) # [N_samples, Seq_len, Batch_size, ...] 
- batch = {k: v[0].float() for k, v in sample.items()} - train( - fabric, - world_model, - actor_task, - critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - batch, - aggregator, - cfg, - ensembles=ensembles, - ensemble_optimizer=ensemble_optimizer, - actor_exploration=actor_exploration, - critic_exploration=critic_exploration, - actor_exploration_optimizer=actor_exploration_optimizer, - critic_exploration_optimizer=critic_exploration_optimizer, + for i in range(per_rank_gradient_steps): + batch = {k: v[i].float() for k, v in sample.items()} + train( + fabric, + world_model, + actor_task, + critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + batch, + aggregator, + cfg, + ensembles=ensembles, + ensemble_optimizer=ensemble_optimizer, + actor_exploration=actor_exploration, + critic_exploration=critic_exploration, + actor_exploration_optimizer=actor_exploration_optimizer, + critic_exploration_optimizer=critic_exploration_optimizer, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size + + if aggregator and not aggregator.disabled: + aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step)) + aggregator.update( + "Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step) ) - train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -778,6 +757,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -814,7 +798,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 6da54a91..1919d825 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -24,7 +24,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following line if you are 
using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -201,7 +201,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if resume_from_checkpoint else 0 # Global variables train_step = 0 @@ -218,27 +217,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): last_log = state["last_log"] if resume_from_checkpoint else 0 last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if resume_from_checkpoint: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if resume_from_checkpoint and not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -268,6 +258,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -284,7 +275,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -344,55 +335,43 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Reset internal agent states player.init_states(reset_envs=dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - if player.actor_type == "exploration": - player.actor_type = "task" - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(cfg.algo.per_rank_gradient_steps): + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + if 
player.actor_type != "task": + player.actor_type = "task" + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): sample = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=1, + n_samples=per_rank_gradient_steps, dtype=None, device=device, from_numpy=cfg.buffer.from_numpy, ) # [N_samples, Seq_len, Batch_size, ...] - batch = {k: v[0].float() for k, v in sample.items()} - train( - fabric, - world_model, - actor_task, - critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - batch, - aggregator, - cfg, + for i in range(per_rank_gradient_steps): + batch = {k: v[i].float() for k, v in sample.items()} + train( + fabric, + world_model, + actor_task, + critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + batch, + aggregator, + cfg, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size + if aggregator and not aggregator.disabled: + aggregator.update("Params/exploration_amount_task", actor_task._get_expl_amount(policy_step)) + aggregator.update( + "Params/exploration_amount_exploration", actor_exploration._get_expl_amount(policy_step) ) - train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -402,6 +381,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -436,7 +420,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 54fad041..df82cbec 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -29,7 +29,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # 
Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -103,7 +103,6 @@ def train( critic_exploration_optimizer (_FabricOptimizer): the optimizer of the critic for exploration. is_continuous (bool): whether or not are continuous actions. actions_dim (Sequence[int]): the actions dimension. - is_exploring (bool): whether the agent is exploring. """ batch_size = cfg.algo.per_rank_batch_size sequence_length = cfg.algo.per_rank_sequence_length @@ -702,7 +701,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables train_step = 0 @@ -719,27 +717,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -770,7 +759,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() - per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -803,7 +792,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_actions(normalized_obs, mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -865,73 +854,56 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Reset internal agent states player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - n_samples = ( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ) - local_data = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - 
sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=n_samples, - dtype=None, - device=fabric.device, - from_numpy=cfg.buffer.from_numpy, - ) - # Start training - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): - tcp.data.copy_(cp.data) - for cp, tcp in zip( - critic_exploration.module.parameters(), target_critic_exploration.parameters() - ): - tcp.data.copy_(cp.data) - batch = {k: v[i].float() for k, v in local_data.items()} - train( - fabric, - world_model, - actor_task, - critic_task, - target_critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - batch, - aggregator, - cfg, - ensembles=ensembles, - ensemble_optimizer=ensemble_optimizer, - actor_exploration=actor_exploration, - critic_exploration=critic_exploration, - target_critic_exploration=target_critic_exploration, - actor_exploration_optimizer=actor_exploration_optimizer, - critic_exploration_optimizer=critic_exploration_optimizer, - is_continuous=is_continuous, - actions_dim=actions_dim, - ) - train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + local_data = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=fabric.device, + from_numpy=cfg.buffer.from_numpy, ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) + # Start training + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for i in range(per_rank_gradient_steps): + if ( + cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq + == 0 + ): + for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): + tcp.data.copy_(cp.data) + for cp, tcp in zip( + critic_exploration.module.parameters(), target_critic_exploration.parameters() + ): + tcp.data.copy_(cp.data) + batch = {k: v[i].float() for k, v in local_data.items()} + train( + fabric, + world_model, + actor_task, + critic_task, + target_critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + batch, + aggregator, + cfg, + ensembles=ensembles, + ensemble_optimizer=ensemble_optimizer, + actor_exploration=actor_exploration, + critic_exploration=critic_exploration, + target_critic_exploration=target_critic_exploration, + actor_exploration_optimizer=actor_exploration_optimizer, + critic_exploration_optimizer=critic_exploration_optimizer, + is_continuous=is_continuous, + actions_dim=actions_dim, + ) + cumulative_per_rank_gradient_steps += 1 + 
train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -941,6 +913,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -978,7 +955,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index d00f8c69..24b0929b 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -24,7 +24,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -220,7 +220,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if resume_from_checkpoint else 0 # Global variables train_step = 0 @@ -237,27 +236,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): last_log = state["last_log"] if resume_from_checkpoint else 0 last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0 policy_steps_per_update = int(cfg.env.num_envs * world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if resume_from_checkpoint: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -288,7 +278,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): 
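Note on the target-critic refresh in the hunk above (and in the Dreamer-V3-style hunks further down): the update is now gated on `cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq` instead of an environment-update counter, with a hard parameter copy in the V2-style code and a `tau` blend (forced to 1 on the very first gradient step) in the V3-style code. A rough, self-contained sketch of both variants, using throwaway `nn.Linear` critics rather than the real models:

```python
# Sketch of the target-critic refresh gated on gradient steps rather than env
# updates. `critic`/`target_critic` are hypothetical stand-ins; soft=False
# mirrors the V2-style hard copy, soft=True the V3-style tau blend.
import torch
import torch.nn as nn

critic = nn.Linear(8, 1)
target_critic = nn.Linear(8, 1)
per_rank_target_network_update_freq = 4
tau = 0.02

def refresh_target(cumulative_per_rank_gradient_steps: int, soft: bool) -> None:
    if cumulative_per_rank_gradient_steps % per_rank_target_network_update_freq != 0:
        return
    # the first refresh copies everything so the target starts in sync
    blend = 1.0 if (not soft or cumulative_per_rank_gradient_steps == 0) else tau
    with torch.no_grad():
        for cp, tcp in zip(critic.parameters(), target_critic.parameters()):
            tcp.data.copy_(blend * cp.data + (1 - blend) * tcp.data)

for step in range(16):
    # ... one gradient step on `critic` would happen here ...
    refresh_target(step, soft=True)
```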
rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() - per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -305,7 +295,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(normalized_obs, mask) + real_actions = actions = player.get_actions(normalized_obs, mask) actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() @@ -367,64 +357,47 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Reset internal agent states player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - if player.actor_type == "exploration": - player.actor_type = "task" - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) - n_samples = ( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ) - local_data = rb.sample_tensors( - batch_size=cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=n_samples, - dtype=None, - device=fabric.device, - from_numpy=cfg.buffer.from_numpy, - ) - # Start training - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): - tcp.data.copy_(cp.data) - batch = {k: v[i].float() for k, v in local_data.items()} - train( - fabric, - world_model, - actor_task, - critic_task, - target_critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - batch, - aggregator, - cfg, - actions_dim=actions_dim, - ) - train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + if player.actor_type != "task": + player.actor_type = "task" + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) + local_data = rb.sample_tensors( + batch_size=cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=fabric.device, + from_numpy=cfg.buffer.from_numpy, ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) + # Start training + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for i in 
range(per_rank_gradient_steps): + if ( + cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq + == 0 + ): + for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): + tcp.data.copy_(cp.data) + batch = {k: v[i].float() for k, v in local_data.items()} + train( + fabric, + world_model, + actor_task, + critic_task, + target_critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + batch, + aggregator, + cfg, + actions_dim=actions_dim, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -434,6 +407,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -469,7 +447,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv2/utils.py b/sheeprl/algos/p2e_dv2/utils.py index c717ce2d..91847673 100644 --- a/sheeprl/algos/p2e_dv2/utils.py +++ b/sheeprl/algos/p2e_dv2/utils.py @@ -29,8 +29,6 @@ "State/kl", "State/post_entropy", "State/prior_entropy", - "Params/exploration_amount_task", - "Params/exploration_amount_exploration", "Rewards/intrinsic", "Values_exploration/predicted_values", "Values_exploration/lambda_values", diff --git a/sheeprl/algos/p2e_dv3/evaluate.py b/sheeprl/algos/p2e_dv3/evaluate.py index 9dbf8f7a..e59052b7 100644 --- a/sheeprl/algos/p2e_dv3/evaluate.py +++ b/sheeprl/algos/p2e_dv3/evaluate.py @@ -61,7 +61,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): world_model.rssm, actor, actions_dim, - cfg.algo.player.expl_amount, cfg.env.num_envs, cfg.algo.world_model.stochastic_size, cfg.algo.world_model.recurrent_model.recurrent_state_size, diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 106cb655..b11dbae2 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -33,7 +33,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -768,7 +768,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables train_step = 0 @@ -785,27 +784,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): last_log = state["last_log"] if cfg.checkpoint.resume_from 
else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -832,7 +822,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["is_first"] = np.ones_like(step_data["dones"]) player.init_states() - per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -864,7 +854,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) + real_actions = actions = player.get_actions(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() @@ -938,75 +928,59 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - local_data = rb.sample_tensors( - cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ), - dtype=None, - device=fabric.device, - from_numpy=cfg.buffer.from_numpy, - ) - # Start training - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau - for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): - tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) - for k in critics_exploration.keys(): - for cp, tcp in zip( - critics_exploration[k]["module"].module.parameters(), - critics_exploration[k]["target_module"].parameters(), - ): - tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) - batch = {k: v[i].float() for k, v in local_data.items()} - train( - fabric, - world_model, 
- actor_task, - critic_task, - target_critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - batch, - aggregator, - cfg, - ensembles=ensembles, - ensemble_optimizer=ensemble_optimizer, - actor_exploration=actor_exploration, - critics_exploration=critics_exploration, - actor_exploration_optimizer=actor_exploration_optimizer, - is_continuous=is_continuous, - actions_dim=actions_dim, - moments_exploration=moments_exploration, - moments_task=moments_task, - ) - train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + local_data = rb.sample_tensors( + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=fabric.device, + from_numpy=cfg.buffer.from_numpy, ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) + # Start training + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for i in range(per_rank_gradient_steps): + if ( + cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq + == 0 + ): + tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau + for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): + tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + for k in critics_exploration.keys(): + for cp, tcp in zip( + critics_exploration[k]["module"].module.parameters(), + critics_exploration[k]["target_module"].parameters(), + ): + tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + batch = {k: v[i].float() for k, v in local_data.items()} + train( + fabric, + world_model, + actor_task, + critic_task, + target_critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + batch, + aggregator, + cfg, + ensembles=ensembles, + ensemble_optimizer=ensemble_optimizer, + actor_exploration=actor_exploration, + critics_exploration=critics_exploration, + actor_exploration_optimizer=actor_exploration_optimizer, + is_continuous=is_continuous, + actions_dim=actions_dim, + moments_exploration=moments_exploration, + moments_task=moments_task, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -1016,6 +990,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -1061,7 +1040,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_task_optimizer": actor_task_optimizer.state_dict(), 
"critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index b5b482f8..97df1ce6 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -22,7 +22,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay, save_configs +from sheeprl.utils.utils import Ratio, save_configs @register_algorithm() @@ -215,7 +215,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): rb = state["rb"] else: raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") - expl_decay_steps = state["expl_decay_steps"] if resume_from_checkpoint else 0 # Global variables train_step = 0 @@ -232,27 +231,18 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): last_log = state["last_log"] if resume_from_checkpoint else 0 last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0 policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - updates_before_training = cfg.algo.train_every // policy_steps_per_update num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - max_step_expl_decay = cfg.algo.actor.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if resume_from_checkpoint: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -279,7 +269,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): step_data["is_first"] = np.ones_like(step_data["dones"]) player.init_states() - per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -295,7 +285,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} if len(mask) == 0: mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, mask) + real_actions = actions = player.get_actions(preprocessed_obs, mask) actions = torch.cat(actions, -1).cpu().numpy() if is_continuous: real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() @@ 
-369,66 +359,50 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) - updates_before_training -= 1 - # Train the agent - if update >= learning_starts and updates_before_training <= 0: - if player.actor_type == "exploration": - player.actor_type = "task" - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) - local_data = rb.sample_tensors( - cfg.algo.per_rank_batch_size, - sequence_length=cfg.algo.per_rank_sequence_length, - n_samples=( - cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps - ), - dtype=None, - device=fabric.device, - from_numpy=cfg.buffer.from_numpy, - ) - # Start training - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for i in range(next(iter(local_data.values())).shape[0]): - tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): - tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) - batch = {k: v[i].float() for k, v in local_data.items()} - train( - fabric, - world_model, - actor_task, - critic_task, - target_critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - batch, - aggregator, - cfg, - is_continuous=is_continuous, - actions_dim=actions_dim, - moments=moments_task, - ) - train_step += world_size - updates_before_training = cfg.algo.train_every // policy_steps_per_update - if cfg.algo.actor.expl_decay: - expl_decay_steps += 1 - actor_task.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + if player.actor_type != "task": + player.actor_type = "task" + player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) + local_data = rb.sample_tensors( + cfg.algo.per_rank_batch_size, + sequence_length=cfg.algo.per_rank_sequence_length, + n_samples=per_rank_gradient_steps, + dtype=None, + device=fabric.device, + from_numpy=cfg.buffer.from_numpy, ) - actor_exploration.expl_amount = polynomial_decay( - expl_decay_steps, - initial=cfg.algo.actor.expl_amount, - final=cfg.algo.actor.expl_min, - max_decay_steps=max_step_expl_decay, - ) - if aggregator and not aggregator.disabled: - aggregator.update("Params/exploration_amount_task", actor_task.expl_amount) - aggregator.update("Params/exploration_amount_exploration", actor_exploration.expl_amount) + # Start training + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for i in range(per_rank_gradient_steps): + if ( + cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq + == 0 + ): + tau = 1 if cumulative_per_rank_gradient_steps == 0 else cfg.algo.critic.tau + for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): + tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + batch = {k: v[i].float() for k, v in local_data.items()} + train( + fabric, + world_model, + actor_task, + critic_task, + target_critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + batch, + 
aggregator, + cfg, + is_continuous=is_continuous, + actions_dim=actions_dim, + moments=moments_task, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -438,6 +412,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -473,7 +452,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "world_optimizer": world_optimizer.state_dict(), "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), - "expl_decay_steps": expl_decay_steps, + "ratio": ratio.state_dict(), "update": update * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv3/utils.py b/sheeprl/algos/p2e_dv3/utils.py index c126e6c2..c2563336 100644 --- a/sheeprl/algos/p2e_dv3/utils.py +++ b/sheeprl/algos/p2e_dv3/utils.py @@ -28,8 +28,6 @@ "Loss/continue_loss", "Loss/ensemble_loss", "State/kl", - "Params/exploration_amount_task", - "Params/exploration_amount_exploration", "State/post_entropy", "State/prior_entropy", "Grads/world_model", diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 73f1ca43..58753f9f 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -27,7 +27,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import save_configs +from sheeprl.utils.utils import Ratio, save_configs def train( @@ -211,6 +211,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -232,6 +237,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] obs = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) + per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * world_size @@ -283,58 +290,59 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if update >= learning_starts: - training_steps = learning_starts if update == learning_starts else 1 - - # We sample one time to reduce the communications between processes - sample = rb.sample_tensors( - batch_size=training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, - sample_next_obs=cfg.buffer.sample_next_obs, - dtype=None, - device=device, - from_numpy=cfg.buffer.from_numpy, - ) # [G*B] - gathered_data: Dict[str, torch.Tensor] = fabric.all_gather(sample) # [World, G*B] - for k, v in gathered_data.items(): - gathered_data[k] = v.float() # [G*B*World] - if fabric.world_size > 1: - gathered_data[k] = gathered_data[k].flatten(start_dim=0, 
end_dim=2) + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + # We sample one time to reduce the communications between processes + sample = rb.sample_tensors( + batch_size=per_rank_gradient_steps * cfg.algo.per_rank_batch_size, + sample_next_obs=cfg.buffer.sample_next_obs, + dtype=None, + device=device, + from_numpy=cfg.buffer.from_numpy, + ) # [G*B] + gathered_data: Dict[str, torch.Tensor] = fabric.all_gather(sample) # [World, G*B] + for k, v in gathered_data.items(): + gathered_data[k] = v.float() # [G*B*World] + if fabric.world_size > 1: + gathered_data[k] = gathered_data[k].flatten(start_dim=0, end_dim=2) + else: + gathered_data[k] = gathered_data[k].flatten(start_dim=0, end_dim=1) + idxes_to_sample = list(range(next(iter(gathered_data.values())).shape[0])) + if world_size > 1: + dist_sampler: DistributedSampler = DistributedSampler( + idxes_to_sample, + num_replicas=world_size, + rank=fabric.global_rank, + shuffle=True, + seed=cfg.seed, + drop_last=False, + ) + sampler: BatchSampler = BatchSampler( + sampler=dist_sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False + ) else: - gathered_data[k] = gathered_data[k].flatten(start_dim=0, end_dim=1) - idxes_to_sample = list(range(next(iter(gathered_data.values())).shape[0])) - if world_size > 1: - dist_sampler: DistributedSampler = DistributedSampler( - idxes_to_sample, - num_replicas=world_size, - rank=fabric.global_rank, - shuffle=True, - seed=cfg.seed, - drop_last=False, - ) - sampler: BatchSampler = BatchSampler( - sampler=dist_sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False - ) - else: - sampler = BatchSampler( - sampler=idxes_to_sample, batch_size=cfg.algo.per_rank_batch_size, drop_last=False - ) - - # Start training - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for batch_idxes in sampler: - batch = {k: v[batch_idxes] for k, v in gathered_data.items()} - train( - fabric, - agent, - actor_optimizer, - qf_optimizer, - alpha_optimizer, - batch, - aggregator, - update, - cfg, - policy_steps_per_update, + sampler = BatchSampler( + sampler=idxes_to_sample, batch_size=cfg.algo.per_rank_batch_size, drop_last=False ) - train_step += world_size + + # Start training + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for batch_idxes in sampler: + batch = {k: v[batch_idxes] for k, v in gathered_data.items()} + train( + fabric, + agent, + actor_optimizer, + qf_optimizer, + alpha_optimizer, + batch, + aggregator, + update, + cfg, + policy_steps_per_update, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -344,6 +352,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -376,6 +389,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "qf_optimizer": qf_optimizer.state_dict(), "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), + "ratio": ratio.state_dict(), "update": update * fabric.world_size, "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, diff --git 
a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 44a23569..cf75675e 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -26,7 +26,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import save_configs +from sheeprl.utils.utils import Ratio, save_configs @torch.inference_mode() @@ -149,6 +149,11 @@ def player( if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -170,6 +175,8 @@ def player( obs = envs.reset(seed=cfg.seed)[0] obs = np.concatenate([obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) + per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs @@ -220,48 +227,47 @@ def player( # Send data to the training agents if update >= learning_starts: - # Send local info to the trainers - if not first_info_sent: - world_collective.broadcast_object_list( - [{"update": update, "last_log": last_log, "last_checkpoint": last_checkpoint}], src=0 + per_rank_gradient_steps = ratio(policy_step / (fabric.world_size - 1)) + cumulative_per_rank_gradient_steps += per_rank_gradient_steps + if per_rank_gradient_steps > 0: + # Send local info to the trainers + if not first_info_sent: + world_collective.broadcast_object_list( + [{"update": update, "last_log": last_log, "last_checkpoint": last_checkpoint}], src=0 + ) + first_info_sent = True + + # Sample data to be sent to the trainers + sample = rb.sample_tensors( + batch_size=per_rank_gradient_steps * cfg.algo.per_rank_batch_size * (fabric.world_size - 1), + sample_next_obs=cfg.buffer.sample_next_obs, + dtype=None, + device=device, + from_numpy=cfg.buffer.from_numpy, ) - first_info_sent = True - - # Sample data to be sent to the trainers - training_steps = learning_starts if update == learning_starts else 1 - sample = rb.sample_tensors( - batch_size=training_steps - * cfg.algo.per_rank_gradient_steps - * cfg.algo.per_rank_batch_size - * (fabric.world_size - 1), - sample_next_obs=cfg.buffer.sample_next_obs, - dtype=None, - device=device, - from_numpy=cfg.buffer.from_numpy, - ) - # chunks = {k1: [k1_chunk_1, k1_chunk_2, ...], k2: [k2_chunk_1, k2_chunk_2, ...]} - chunks = { - k: v.float().split(training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size) - for k, v in sample.items() - } - # chunks = [{k1: k1_chunk_1, k2: k2_chunk_1}, {k1: k1_chunk_2, k2: k2_chunk_2}, ...] - chunks = [{k: v[i] for k, v in chunks.items()} for i in range(len(chunks[next(iter(chunks.keys()))]))] - world_collective.scatter_object_list([None], [None] + chunks, src=0) + # chunks = {k1: [k1_chunk_1, k1_chunk_2, ...], k2: [k2_chunk_1, k2_chunk_2, ...]} + chunks = { + k: v.float().split(per_rank_gradient_steps * cfg.algo.per_rank_batch_size) + for k, v in sample.items() + } + # chunks = [{k1: k1_chunk_1, k2: k2_chunk_1}, {k1: k1_chunk_2, k2: k2_chunk_2}, ...] 
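Note on the decoupled SAC player below: the buffer query is now sized directly from the owed gradient steps, `ratio(policy_step / (world_size - 1))`, so one flat sample of `per_rank_gradient_steps * per_rank_batch_size * (world_size - 1)` transitions is split into one equally sized chunk per trainer before the scatter. A small sketch of that chunking arithmetic with dummy tensors (the scatter itself is only indicated in a comment):

```python
# Sketch of the player-side chunking in the decoupled SAC hunk: one flat sample
# is split into equal per-trainer chunks. All tensors here are dummies.
import torch

world_size = 3                      # 1 player + 2 trainers
num_trainers = world_size - 1
per_rank_gradient_steps = 5
per_rank_batch_size = 16

flat = {
    "obs": torch.randn(per_rank_gradient_steps * per_rank_batch_size * num_trainers, 8),
    "actions": torch.randn(per_rank_gradient_steps * per_rank_batch_size * num_trainers, 2),
}
# split every key into `num_trainers` chunks of G*B transitions each
chunks = {k: v.split(per_rank_gradient_steps * per_rank_batch_size) for k, v in flat.items()}
# regroup per trainer: one dict per chunk, as expected by scatter_object_list
per_trainer = [{k: v[i] for k, v in chunks.items()} for i in range(num_trainers)]
assert len(per_trainer) == num_trainers
assert per_trainer[0]["obs"].shape[0] == per_rank_gradient_steps * per_rank_batch_size
# world_collective.scatter_object_list([None], [None] + per_trainer, src=0) would ship them
```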
+ chunks = [{k: v[i] for k, v in chunks.items()} for i in range(len(chunks[next(iter(chunks.keys()))]))] + world_collective.scatter_object_list([None], [None] + chunks, src=0) - # Wait the trainers to finish - player_trainer_collective.broadcast(flattened_parameters, src=1) + # Wait the trainers to finish + player_trainer_collective.broadcast(flattened_parameters, src=1) - # Convert back the parameters - torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, actor.parameters()) + # Convert back the parameters + torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, actor.parameters()) - # Logs trainers-only metrics - if cfg.metric.log_level > 0 and policy_step - last_log >= cfg.metric.log_every: - # Gather metrics from the trainers - metrics = [None] - player_trainer_collective.broadcast_object_list(metrics, src=1) + # Logs trainers-only metrics + if cfg.metric.log_level > 0 and policy_step - last_log >= cfg.metric.log_every: + # Gather metrics from the trainers + metrics = [None] + player_trainer_collective.broadcast_object_list(metrics, src=1) - # Log metrics - fabric.log_dict(metrics[0], policy_step) + # Log metrics + fabric.log_dict(metrics[0], policy_step) # Logs player-only metrics if cfg.metric.log_level > 0 and policy_step - last_log >= cfg.metric.log_every: @@ -269,6 +275,13 @@ def player( fabric.log_dict(aggregator.compute(), policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", + cumulative_per_rank_gradient_steps * (fabric.world_size - 1) / policy_step, + policy_step, + ) + # Sync timers if not timer.disabled: timer_metrics = timer.compute() @@ -296,6 +309,7 @@ def player( player_trainer_collective=player_trainer_collective, ckpt_path=ckpt_path, replay_buffer=rb if cfg.buffer.checkpoint else None, + ratio_state_dict=ratio.state_dict(), ) world_collective.scatter_object_list([None], [None] + [-1] * (world_collective.world_size - 1), src=0) @@ -309,6 +323,7 @@ def player( player_trainer_collective=player_trainer_collective, ckpt_path=ckpt_path, replay_buffer=rb if cfg.buffer.checkpoint else None, + ratio_state_dict=ratio.state_dict(), ) envs.close() @@ -401,7 +416,7 @@ def trainer( if not MetricAggregator.disabled: aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device) - # Receive data from player reagrding the: + # Receive data from player regarding the: # * update # * last_log # * last_checkpoint diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 09c47b56..f5a80ed6 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -31,7 +31,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import save_configs +from sheeprl.utils.utils import Ratio, save_configs def train( @@ -46,16 +46,12 @@ def train( decoder_optimizer: Optimizer, data: Dict[str, Tensor], aggregator: MetricAggregator | None, - update: int, + cumulative_per_rank_gradient_steps: int, cfg: Dict[str, Any], - policy_steps_per_update: int, group: Optional[CollectibleGroup] = None, ): - critic_target_network_frequency = cfg.algo.critic.target_network_frequency // policy_steps_per_update + 1 - actor_network_frequency = cfg.algo.actor.network_frequency // policy_steps_per_update + 1 - decoder_update_freq = cfg.algo.decoder.update_freq // policy_steps_per_update + 1 - normalized_obs = {} normalized_next_obs = {} + 
normalized_obs = {} for k in cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder: if k in cfg.algo.cnn_keys.encoder: normalized_obs[k] = data[k] / 255.0 @@ -77,12 +73,12 @@ def train( aggregator.update("Loss/value_loss", qf_loss) # Update the target networks with EMA - if update % critic_target_network_frequency == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.critic.per_rank_target_network_update_freq == 0: agent.critic_target_ema() agent.critic_encoder_target_ema() # Update the actor - if update % actor_network_frequency == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.actor.per_rank_update_freq == 0: actions, logprobs = agent.get_actions_and_log_probs(normalized_obs, detach_encoder_features=True) qf_values = agent.get_q_values(normalized_obs, actions, detach_encoder_features=True) min_qf_values = torch.min(qf_values, dim=-1, keepdim=True)[0] @@ -103,7 +99,7 @@ def train( aggregator.update("Loss/alpha_loss", alpha_loss) # Update the decoder - if update % decoder_update_freq == 0: + if cumulative_per_rank_gradient_steps % cfg.algo.decoder.per_rank_update_freq == 0: hidden = encoder(normalized_obs) reconstruction = decoder(hidden) reconstruction_loss = 0 @@ -284,6 +280,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not cfg.buffer.checkpoint: learning_starts += start_step + # Create Ratio class + ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) + if cfg.checkpoint.resume_from: + ratio.load_state_dict(state["ratio"]) + # Warning for log and checkpoint every if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -307,13 +308,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if k in cfg.algo.cnn_keys.encoder: obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) + per_rank_gradient_steps = 0 + cumulative_per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): policy_step += cfg.env.num_envs * fabric.world_size # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - if update < learning_starts: + if update <= learning_starts: actions = envs.action_space.sample() else: with torch.inference_mode(): @@ -363,56 +366,56 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = next_obs # Train the agent - if update >= learning_starts - 1: - training_steps = learning_starts if update == learning_starts - 1 else 1 - - # We sample one time to reduce the communications between processes - sample = rb.sample_tensors( - training_steps * cfg.algo.per_rank_gradient_steps * cfg.algo.per_rank_batch_size, - sample_next_obs=cfg.buffer.sample_next_obs, - from_numpy=cfg.buffer.from_numpy, - ) # [G*B, 1] - gathered_data = fabric.all_gather(sample) # [G*B, World, 1] - flatten_dim = 3 if fabric.world_size > 1 else 2 - gathered_data = {k: v.view(-1, *v.shape[flatten_dim:]) for k, v in gathered_data.items()} # [G*B*World] - len_data = len(gathered_data[next(iter(gathered_data.keys()))]) - if fabric.world_size > 1: - dist_sampler: DistributedSampler = DistributedSampler( - range(len_data), - num_replicas=fabric.world_size, - rank=fabric.global_rank, - shuffle=True, - seed=cfg.seed, - drop_last=False, - ) - sampler: BatchSampler = BatchSampler( - sampler=dist_sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False - ) - else: - sampler = BatchSampler( - sampler=range(len_data), 
batch_size=cfg.algo.per_rank_batch_size, drop_last=False - ) - - # Start training - with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): - for batch_idxes in sampler: - train( - fabric, - agent, - encoder, - decoder, - actor_optimizer, - qf_optimizer, - alpha_optimizer, - encoder_optimizer, - decoder_optimizer, - {k: v[batch_idxes] for k, v in gathered_data.items()}, - aggregator, - update, - cfg, - policy_steps_per_update, + if update >= learning_starts: + per_rank_gradient_steps = ratio(policy_step / world_size) + if per_rank_gradient_steps > 0: + # We sample one time to reduce the communications between processes + sample = rb.sample_tensors( + per_rank_gradient_steps * cfg.algo.per_rank_batch_size, + sample_next_obs=cfg.buffer.sample_next_obs, + from_numpy=cfg.buffer.from_numpy, + ) # [1, G*B] + gathered_data: Dict[str, torch.Tensor] = fabric.all_gather(sample) # [World, 1, G*B] + for k, v in gathered_data.items(): + gathered_data[k] = v.flatten(start_dim=0, end_dim=2).float() # [G*B*World] + len_data = len(gathered_data[next(iter(gathered_data.keys()))]) + if fabric.world_size > 1: + dist_sampler: DistributedSampler = DistributedSampler( + range(len_data), + num_replicas=fabric.world_size, + rank=fabric.global_rank, + shuffle=True, + seed=cfg.seed, + drop_last=False, ) - train_step += world_size + sampler: BatchSampler = BatchSampler( + sampler=dist_sampler, batch_size=cfg.algo.per_rank_batch_size, drop_last=False + ) + else: + sampler = BatchSampler( + sampler=range(len_data), batch_size=cfg.algo.per_rank_batch_size, drop_last=False + ) + + # Start training + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): + for batch_idxes in sampler: + train( + fabric, + agent, + encoder, + decoder, + actor_optimizer, + qf_optimizer, + alpha_optimizer, + encoder_optimizer, + decoder_optimizer, + {k: v[batch_idxes] for k, v in gathered_data.items()}, + aggregator, + cumulative_per_rank_gradient_steps, + cfg, + ) + cumulative_per_rank_gradient_steps += 1 + train_step += world_size # Log metrics if cfg.metric.log_level and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): @@ -422,6 +425,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Log replay ratio + fabric.log( + "Params/replay_ratio", cumulative_per_rank_gradient_steps * world_size / policy_step, policy_step + ) + # Sync distributed timers if not timer.disabled: timer_metrics = timer.compute() @@ -458,6 +466,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "alpha_optimizer": alpha_optimizer.state_dict(), "encoder_optimizer": encoder_optimizer.state_dict(), "decoder_optimizer": decoder_optimizer.state_dict(), + "ratio": ratio.state_dict(), "update": update * fabric.world_size, "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, diff --git a/sheeprl/configs/algo/dreamer_v1.yaml b/sheeprl/configs/algo/dreamer_v1.yaml index 212668f5..2aaa7b40 100644 --- a/sheeprl/configs/algo/dreamer_v1.yaml +++ b/sheeprl/configs/algo/dreamer_v1.yaml @@ -11,9 +11,9 @@ horizon: 15 name: dreamer_v1 # Training recipe -train_every: 1000 +replay_ratio: 0.1 learning_starts: 5000 -per_rank_gradient_steps: 100 +per_rank_pretrain_steps: 0 per_rank_sequence_length: ??? 
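Note on the config changes here and below: the old `train_every` / `per_rank_gradient_steps` pair collapses into the single `replay_ratio` knob, and for the single-rank case the converted defaults appear to follow `replay_ratio ≈ per_rank_gradient_steps / train_every` (Dreamer-V3 below is instead retuned to `replay_ratio: 1` with a much smaller `learning_starts`). A quick check of that arithmetic against the values in this diff; the helper name is illustrative only:

```python
# Rough conversion rule for reading the old knobs against the new one
# (single-rank approximation).
def to_replay_ratio(per_rank_gradient_steps: int, train_every: int) -> float:
    return per_rank_gradient_steps / train_every

print(to_replay_ratio(100, 1000))  # Dreamer-V1 above              -> 0.1
print(to_replay_ratio(1, 5))       # Dreamer-V2 below              -> 0.2
print(to_replay_ratio(1, 16))      # Dreamer-V2 ms_pacman below    -> 0.0625
print(to_replay_ratio(20, 1))      # DroQ below (trained every step) -> 20.0
```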
# Encoder and decoder keys @@ -100,8 +100,7 @@ actor: clip_gradients: 100.0 expl_amount: 0.3 expl_min: 0.0 - expl_decay: False - max_step_expl_decay: 200000 + expl_decay: 0.0 # Actor optimizer optimizer: diff --git a/sheeprl/configs/algo/dreamer_v2.yaml b/sheeprl/configs/algo/dreamer_v2.yaml index 719d91ec..2261aaf8 100644 --- a/sheeprl/configs/algo/dreamer_v2.yaml +++ b/sheeprl/configs/algo/dreamer_v2.yaml @@ -11,9 +11,8 @@ lmbda: 0.95 horizon: 15 # Training recipe -train_every: 5 +replay_ratio: 0.2 learning_starts: 1000 -per_rank_gradient_steps: 1 per_rank_pretrain_steps: 100 per_rank_sequence_length: ??? @@ -111,10 +110,6 @@ actor: layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} clip_gradients: 100.0 - expl_amount: 0.0 - expl_min: 0.0 - expl_decay: False - max_step_expl_decay: 0 # Actor optimizer optimizer: @@ -128,7 +123,7 @@ critic: mlp_layers: ${algo.mlp_layers} layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} - target_network_update_freq: 100 + per_rank_target_network_update_freq: 100 clip_gradients: 100.0 # Critic optimizer diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml index c033f1d7..704a4bfe 100644 --- a/sheeprl/configs/algo/dreamer_v3.yaml +++ b/sheeprl/configs/algo/dreamer_v3.yaml @@ -13,10 +13,9 @@ lmbda: 0.95 horizon: 15 # Training recipe -train_every: 16 -learning_starts: 65536 -per_rank_pretrain_steps: 1 -per_rank_gradient_steps: 1 +replay_ratio: 1 +learning_starts: 1024 +per_rank_pretrain_steps: 0 per_rank_sequence_length: ??? # Encoder and decoder keys @@ -116,10 +115,6 @@ actor: layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} clip_gradients: 100.0 - expl_amount: 0.0 - expl_min: 0.0 - expl_decay: False - max_step_expl_decay: 0 # Disttributed percentile model (used to scale the values) moments: @@ -141,7 +136,7 @@ critic: mlp_layers: ${algo.mlp_layers} layer_norm: ${algo.layer_norm} dense_units: ${algo.dense_units} - target_network_update_freq: 1 + per_rank_target_network_update_freq: 1 tau: 0.02 bins: 255 clip_gradients: 100.0 diff --git a/sheeprl/configs/algo/droq.yaml b/sheeprl/configs/algo/droq.yaml index 82a88b2a..29d0aff1 100644 --- a/sheeprl/configs/algo/droq.yaml +++ b/sheeprl/configs/algo/droq.yaml @@ -5,7 +5,7 @@ defaults: name: droq # Training recipe -per_rank_gradient_steps: 20 +replay_ratio: 20.0 # Override from `sac` config critic: diff --git a/sheeprl/configs/algo/sac.yaml b/sheeprl/configs/algo/sac.yaml index 452f447e..bc730cd2 100644 --- a/sheeprl/configs/algo/sac.yaml +++ b/sheeprl/configs/algo/sac.yaml @@ -11,8 +11,9 @@ gamma: 0.99 hidden_size: 256 # Training recipe +replay_ratio: 1.0 learning_starts: 100 -per_rank_gradient_steps: 1 +per_rank_pretrain_steps: 0 # Model related parameters # Actor diff --git a/sheeprl/configs/algo/sac_ae.yaml b/sheeprl/configs/algo/sac_ae.yaml index e7dfd94b..0111684e 100644 --- a/sheeprl/configs/algo/sac_ae.yaml +++ b/sheeprl/configs/algo/sac_ae.yaml @@ -7,7 +7,9 @@ defaults: name: sac_ae # Training recipe +replay_ratio: 1.0 learning_starts: 1000 +per_rank_pretrain_steps: 0 # Model related parameters cnn_channels_multiplier: 16 @@ -38,7 +40,7 @@ encoder: # Decoder decoder: l2_lambda: 1e-6 - update_freq: 1 + per_rank_update_freq: 1 cnn_channels_multiplier: ${algo.cnn_channels_multiplier} dense_units: ${algo.dense_units} mlp_layers: ${algo.mlp_layers} @@ -53,12 +55,12 @@ decoder: tau: 0.01 hidden_size: 1024 actor: - network_frequency: 2 + per_rank_update_freq: 2 optimizer: lr: 1e-3 eps: 1e-08 critic: - target_network_frequency: 2 + 
per_rank_target_network_update_freq: 2 optimizer: lr: 1e-3 eps: 1e-08 diff --git a/sheeprl/configs/exp/dreamer_v1_benchmarks.yaml b/sheeprl/configs/exp/dreamer_v1_benchmarks.yaml index efa04f19..12f29b1f 100644 --- a/sheeprl/configs/exp/dreamer_v1_benchmarks.yaml +++ b/sheeprl/configs/exp/dreamer_v1_benchmarks.yaml @@ -26,10 +26,10 @@ buffer: # Algorithm algo: learning_starts: 1024 - train_every: 16 + dense_units: 8 mlp_layers: 1 - per_rank_gradient_steps: 1 + world_model: stochastic_size: 4 encoder: diff --git a/sheeprl/configs/exp/dreamer_v2.yaml b/sheeprl/configs/exp/dreamer_v2.yaml index 66faf0c9..5565d62d 100644 --- a/sheeprl/configs/exp/dreamer_v2.yaml +++ b/sheeprl/configs/exp/dreamer_v2.yaml @@ -63,9 +63,6 @@ metric: State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} Grads/world_model: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/dreamer_v2_benchmarks.yaml b/sheeprl/configs/exp/dreamer_v2_benchmarks.yaml index 27bcb515..cfa2977a 100644 --- a/sheeprl/configs/exp/dreamer_v2_benchmarks.yaml +++ b/sheeprl/configs/exp/dreamer_v2_benchmarks.yaml @@ -27,7 +27,7 @@ buffer: algo: learning_starts: 1024 per_rank_pretrain_steps: 1 - train_every: 16 + dense_units: 8 mlp_layers: world_model: diff --git a/sheeprl/configs/exp/dreamer_v2_crafter.yaml b/sheeprl/configs/exp/dreamer_v2_crafter.yaml index db7e5249..30b80209 100644 --- a/sheeprl/configs/exp/dreamer_v2_crafter.yaml +++ b/sheeprl/configs/exp/dreamer_v2_crafter.yaml @@ -40,7 +40,6 @@ mlp_keys: # Algorithm algo: gamma: 0.999 - train_every: 5 layer_norm: True learning_starts: 10000 per_rank_pretrain_steps: 1 diff --git a/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml b/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml index dc8b146b..2d88fdf6 100644 --- a/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml +++ b/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml @@ -26,12 +26,12 @@ buffer: # Algorithm algo: gamma: 0.995 - train_every: 16 + replay_ratio: 0.0625 total_steps: 200000000 learning_starts: 200000 per_rank_batch_size: 32 per_rank_pretrain_steps: 1 - per_rank_gradient_steps: 1 + world_model: use_continues: True kl_free_nats: 0.0 diff --git a/sheeprl/configs/exp/dreamer_v3.yaml b/sheeprl/configs/exp/dreamer_v3.yaml index 5907104d..fc51c1d2 100644 --- a/sheeprl/configs/exp/dreamer_v3.yaml +++ b/sheeprl/configs/exp/dreamer_v3.yaml @@ -8,6 +8,7 @@ defaults: # Algorithm algo: + replay_ratio: 1 total_steps: 5000000 per_rank_batch_size: 16 per_rank_sequence_length: 64 @@ -61,9 +62,6 @@ metric: State/prior_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} Grads/world_model: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml b/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml index 7479c7e3..0a8f9eda 100644 --- a/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml +++ b/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml @@ -30,6 +30,5 @@ buffer: # Algorithm algo: - train_every: 1 total_steps: 100000 learning_starts: 1024 \ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml b/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml index ca440728..8c85d19e 100644 --- 
a/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml +++ b/sheeprl/configs/exp/dreamer_v3_100k_ms_pacman.yaml @@ -26,6 +26,5 @@ buffer: # Algorithm algo: - learning_starts: 1024 total_steps: 100000 - train_every: 1 + learning_starts: 1024 \ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml index 0539a325..cb1ac4c1 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml @@ -30,7 +30,7 @@ buffer: # Algorithm algo: learning_starts: 65536 - train_every: 8 + replay_ratio: 0.125 cnn_keys: encoder: - frame diff --git a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml index 5c3636b4..a2c6b78f 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml @@ -38,7 +38,7 @@ algo: total_steps: 10000000 per_rank_batch_size: 8 learning_starts: 65536 - train_every: 8 + replay_ratio: 0.125 cnn_keys: encoder: - frame diff --git a/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml b/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml index 5f4e2f8c..23a762c3 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_navigate.yaml @@ -28,7 +28,7 @@ buffer: # Algorithm algo: - train_every: 16 + replay_ratio: 0.015625 learning_starts: 65536 cnn_keys: encoder: diff --git a/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml b/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml index 8f0a136f..666f93dc 100644 --- a/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml +++ b/sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml @@ -25,8 +25,8 @@ buffer: # Algorithm algo: - train_every: 2 - learning_starts: 1024 + + replay_ratio: 0.5 cnn_keys: encoder: - rgb diff --git a/sheeprl/configs/exp/dreamer_v3_benchmarks.yaml b/sheeprl/configs/exp/dreamer_v3_benchmarks.yaml index b0b83b17..e10dfd96 100644 --- a/sheeprl/configs/exp/dreamer_v3_benchmarks.yaml +++ b/sheeprl/configs/exp/dreamer_v3_benchmarks.yaml @@ -26,7 +26,7 @@ buffer: # Algorithm algo: learning_starts: 1024 - train_every: 16 + replay_ratio: 1 dense_units: 8 mlp_layers: 1 world_model: diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml index 7003be3a..704a07c5 100644 --- a/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml +++ b/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml @@ -40,7 +40,7 @@ algo: mlp_keys: encoder: [] learning_starts: 1024 - train_every: 2 + # Metric metric: diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml index 4c16b627..9964c3c3 100644 --- a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml +++ b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml @@ -38,7 +38,7 @@ algo: mlp_keys: encoder: [] learning_starts: 1024 - train_every: 2 + replay_ratio: 0.5 # Metric metric: diff --git a/sheeprl/configs/exp/dreamer_v3_super_mario_bros.yaml b/sheeprl/configs/exp/dreamer_v3_super_mario_bros.yaml index 1c5d8546..2c219281 100644 --- a/sheeprl/configs/exp/dreamer_v3_super_mario_bros.yaml +++ b/sheeprl/configs/exp/dreamer_v3_super_mario_bros.yaml @@ -28,7 +28,7 @@ algo: mlp_keys: encoder: [] learning_starts: 16384 - train_every: 4 + replay_ratio: 0.25 # Metric metric: diff --git a/sheeprl/configs/exp/p2e_dv2_exploration.yaml 
b/sheeprl/configs/exp/p2e_dv2_exploration.yaml index bae53323..3c33a758 100644 --- a/sheeprl/configs/exp/p2e_dv2_exploration.yaml +++ b/sheeprl/configs/exp/p2e_dv2_exploration.yaml @@ -51,12 +51,6 @@ metric: State/prior_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_task: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_exploration: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} Rewards/intrinsic: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/p2e_dv2_finetuning.yaml b/sheeprl/configs/exp/p2e_dv2_finetuning.yaml index 1d315969..e55ca8b2 100644 --- a/sheeprl/configs/exp/p2e_dv2_finetuning.yaml +++ b/sheeprl/configs/exp/p2e_dv2_finetuning.yaml @@ -52,12 +52,6 @@ metric: State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_task: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_exploration: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} Grads/world_model: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml b/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml index df1f356b..6eda91af 100644 --- a/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml +++ b/sheeprl/configs/exp/p2e_dv3_expl_L_doapp_128px_gray_combo_discrete_15Mexpl_20Mstps.yaml @@ -35,7 +35,7 @@ buffer: # Algorithm algo: learning_starts: 131072 - train_every: 1 + dense_units: 768 mlp_layers: 4 world_model: diff --git a/sheeprl/configs/exp/p2e_dv3_exploration.yaml b/sheeprl/configs/exp/p2e_dv3_exploration.yaml index 009c2d99..66b475a3 100644 --- a/sheeprl/configs/exp/p2e_dv3_exploration.yaml +++ b/sheeprl/configs/exp/p2e_dv3_exploration.yaml @@ -48,12 +48,6 @@ metric: State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_task: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_exploration: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} State/post_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/p2e_dv3_finetuning.yaml b/sheeprl/configs/exp/p2e_dv3_finetuning.yaml index 3d67448a..d03b1495 100644 --- a/sheeprl/configs/exp/p2e_dv3_finetuning.yaml +++ b/sheeprl/configs/exp/p2e_dv3_finetuning.yaml @@ -8,7 +8,7 @@ defaults: algo: name: p2e_dv3_finetuning - learning_starts: 65536 + learning_starts: 16384 total_steps: 1000000 player: actor_type: exploration @@ -46,12 +46,6 @@ metric: State/kl: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_task: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} - Params/exploration_amount_exploration: - _target_: torchmetrics.MeanMetric - sync_on_compute: ${metric.sync_on_compute} State/post_entropy: _target_: torchmetrics.MeanMetric sync_on_compute: ${metric.sync_on_compute} diff --git a/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml 
b/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml index 8dcad491..94d4d95a 100644 --- a/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml +++ b/sheeprl/configs/exp/p2e_dv3_fntn_L_doapp_64px_gray_combo_discrete_5Mstps.yaml @@ -37,7 +37,7 @@ buffer: # Algorithm algo: learning_starts: 65536 - train_every: 1 + dense_units: 768 mlp_layers: 4 world_model: diff --git a/sheeprl/configs/exp/sac_benchmarks.yaml b/sheeprl/configs/exp/sac_benchmarks.yaml index 63dc2086..b3ce9a7d 100644 --- a/sheeprl/configs/exp/sac_benchmarks.yaml +++ b/sheeprl/configs/exp/sac_benchmarks.yaml @@ -15,7 +15,7 @@ env: algo: name: sac learning_starts: 100 - per_rank_gradient_steps: 1 + per_rank_batch_size: 512 # # If you want to run this benchmark with older versions, # you need to comment the test function in the `./sheeprl/algos/ppo/ppo.py` file. diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index 73672e39..bbf10d5a 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -317,7 +317,10 @@ def sample_tensors( Dict[str, Tensor]: the sampled dictionary, containing the sampled array, one for every key, with a shape of [n_samples, batch_size, ...] """ - samples = self.sample(batch_size=batch_size, sample_next_obs=sample_next_obs, clone=clone, **kwargs) + n_samples = kwargs.pop("n_samples", 1) + samples = self.sample( + batch_size=batch_size, sample_next_obs=sample_next_obs, clone=clone, n_samples=n_samples, **kwargs + ) return { k: get_tensor(v, dtype=dtype, clone=clone, device=device, from_numpy=from_numpy) for k, v in samples.items() } diff --git a/sheeprl/utils/callback.py b/sheeprl/utils/callback.py index 577b1561..eaa099c6 100644 --- a/sheeprl/utils/callback.py +++ b/sheeprl/utils/callback.py @@ -61,6 +61,7 @@ def on_checkpoint_player( player_trainer_collective: TorchCollective, ckpt_path: str, replay_buffer: Optional["ReplayBuffer"] = None, + ratio_state_dict: Dict[str, Any] | None = None, ): state = [None] player_trainer_collective.broadcast_object_list(state, src=1) @@ -68,6 +69,8 @@ def on_checkpoint_player( if replay_buffer is not None: rb_state = self._ckpt_rb(replay_buffer) state["rb"] = replay_buffer + if ratio_state_dict is not None: + state["ratio"] = ratio_state_dict fabric.save(ckpt_path, state) if replay_buffer is not None: self._experiment_consistent_rb(replay_buffer, rb_state) diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py index d85eb16a..7e350983 100644 --- a/sheeprl/utils/utils.py +++ b/sheeprl/utils/utils.py @@ -2,7 +2,8 @@ import copy import os -from typing import Any, Dict, Optional, Sequence, Tuple, Union +import warnings +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union import numpy as np import rich.syntax @@ -195,3 +196,47 @@ def unwrap_fabric(model: _FabricModule | nn.Module) -> nn.Module: def save_configs(cfg: dotdict, log_dir: str): OmegaConf.save(cfg.as_dict(), os.path.join(log_dir, "config.yaml"), resolve=True) + + +class Ratio: + """Directly taken from Hafner et al. 
(2023) implementation: + https://github.com/danijar/dreamerv3/blob/8fa35f83eee1ce7e10f3dee0b766587d0a713a60/dreamerv3/embodied/core/when.py#L26 + """ + + def __init__(self, ratio: float, pretrain_steps: int = 0): + if pretrain_steps < 0: + raise ValueError(f"'pretrain_steps' must be non-negative, got {pretrain_steps}") + if ratio < 0: + raise ValueError(f"'ratio' must be non-negative, got {ratio}") + self._pretrain_steps = pretrain_steps + self._ratio = ratio + self._prev = None + + def __call__(self, step: int) -> int: + if self._ratio == 0: + return 0 + if self._prev is None: + self._prev = step + repeats = 1 + if self._pretrain_steps > 0: + if step < self._pretrain_steps: + warnings.warn( + "The number of pretrain steps is greater than the number of current steps. This could lead to " + f"a higher ratio than the one specified ({self._ratio}). Setting the 'pretrain_steps' equal to " + "the number of current steps." + ) + self._pretrain_steps = step + repeats = round(self._pretrain_steps * self._ratio) + return repeats + repeats = round((step - self._prev) * self._ratio) + self._prev += repeats / self._ratio + return repeats + + def state_dict(self) -> Dict[str, Any]: + return {"_ratio": self._ratio, "_prev": self._prev, "_pretrain_steps": self._pretrain_steps} + + def load_state_dict(self, state_dict: Mapping[str, Any]): + self._ratio = state_dict["_ratio"] + self._prev = state_dict["_prev"] + self._pretrain_steps = state_dict["_pretrain_steps"] + return self diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index 3d6fc985..132d7a9d 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -69,7 +69,7 @@ def test_droq(standard_args, start_time): "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", f"root_dir={root_dir}", f"run_name={run_name}", ] @@ -87,7 +87,7 @@ def test_sac(standard_args, start_time): "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", f"root_dir={root_dir}", f"run_name={run_name}", ] @@ -105,7 +105,7 @@ def test_sac_ae(standard_args, start_time): "algo.per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", f"root_dir={root_dir}", f"run_name={run_name}", "algo.mlp_keys.encoder=[state]", @@ -114,8 +114,8 @@ def test_sac_ae(standard_args, start_time): "algo.hidden_size=4", "algo.dense_units=4", "algo.cnn_channels_multiplier=2", - "algo.actor.network_frequency=1", - "algo.decoder.update_freq=1", + "algo.actor.per_rank_update_freq=1", + "algo.decoder.per_rank_update_freq=1", ] with mock.patch.object(sys, "argv", args): @@ -130,7 +130,7 @@ def test_sac_decoupled(standard_args, start_time): "exp=sac_decoupled", "algo.per_rank_batch_size=1", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", f"fabric.devices={os.environ['LT_DEVICES']}", f"root_dir={root_dir}", f"run_name={run_name}", @@ -239,7 +239,7 @@ def test_dreamer_v1(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", "algo.horizon=2", f"env.id={env_id}", f"root_dir={root_dir}", @@ -270,7 +270,7 @@ def test_p2e_dv1(standard_args, env_id, start_time): 
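The `Ratio` class added to `sheeprl/utils/utils.py` above also absorbs the old pretraining behaviour: on its first call it returns a single gradient step, unless `pretrain_steps > 0`, in which case it pays out `round(pretrain_steps * ratio)` steps at once (clamping `pretrain_steps` to the current step with a warning). This is what the remaining `per_rank_pretrain_steps` config entries feed into, e.g. dreamer_v2 keeps `per_rank_pretrain_steps: 100` next to `replay_ratio: 0.2`. A small sketch with illustrative step counts, again assuming the `sheeprl` package from this branch:

```python
# First-call behaviour of Ratio, with and without pretraining (illustrative numbers).
from sheeprl.utils.utils import Ratio

# No pretraining: the very first call returns a single gradient step.
r = Ratio(ratio=0.5, pretrain_steps=0)
print(r(64))    # 1
print(r(128))   # round((128 - 64) * 0.5) = 32

# With pretraining: the first call pays out round(pretrain_steps * ratio) steps at once.
r = Ratio(ratio=0.2, pretrain_steps=100)
print(r(1000))  # round(100 * 0.2) = 20
print(r(1005))  # steady-state schedule: round((1005 - 1000) * 0.2) = 1
```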
"algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", "algo.horizon=4", "env.id=" + env_id, f"root_dir={root_dir}", @@ -311,7 +311,7 @@ def test_p2e_dv1(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", "algo.horizon=4", "env=dummy", "env.id=" + env_id, @@ -341,7 +341,7 @@ def test_dreamer_v2(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -377,7 +377,7 @@ def test_p2e_dv2(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", "algo.horizon=4", "env.id=" + env_id, f"root_dir={root_dir}", @@ -418,7 +418,7 @@ def test_p2e_dv2(standard_args, env_id, start_time): "algo.per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", "algo.horizon=4", "env=dummy", "env.id=" + env_id, @@ -448,7 +448,7 @@ def test_dreamer_v3(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -459,7 +459,6 @@ def test_dreamer_v3(standard_args, env_id, start_time): "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", "algo.layer_norm=True", - "algo.train_every=1", "algo.cnn_keys.encoder=[rgb]", "algo.cnn_keys.decoder=[rgb]", ] @@ -483,7 +482,7 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -494,7 +493,6 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", "algo.layer_norm=True", - "algo.train_every=1", "buffer.checkpoint=True", "algo.cnn_keys.encoder=[rgb]", "algo.cnn_keys.decoder=[rgb]", @@ -526,7 +524,7 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.per_rank_gradient_steps=1", + "algo.replay_ratio=1", "algo.horizon=8", "env=dummy", "env.id=" + env_id, @@ -538,7 +536,6 @@ def test_p2e_dv3(standard_args, env_id, start_time): "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", "algo.layer_norm=True", - "algo.train_every=1", "algo.cnn_keys.encoder=[rgb]", "algo.cnn_keys.decoder=[rgb]", ] diff --git a/tests/test_algos/test_cli.py b/tests/test_algos/test_cli.py index 701a5e7a..0e95871f 100644 --- a/tests/test_algos/test_cli.py +++ b/tests/test_algos/test_cli.py @@ -125,12 +125,12 @@ def test_resume_from_checkpoint(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " 
"env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " - f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " + f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, check=True, @@ -168,12 +168,12 @@ def test_resume_from_checkpoint_env_error(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " - f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " + f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, check=True, @@ -221,12 +221,12 @@ def test_resume_from_checkpoint_algo_error(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " - f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " + f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, check=True, @@ -276,12 +276,12 @@ def test_evaluate(): sys.executable + " sheeprl.py exp=dreamer_v3 env=dummy dry_run=True " "env.capture_video=False algo.dense_units=8 algo.horizon=8 " "algo.cnn_keys.encoder=[rgb] algo.cnn_keys.decoder=[rgb] " - "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.replay_ratio=1 " "algo.world_model.recurrent_model.recurrent_state_size=8 " "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " "algo.layer_norm=True algo.per_rank_batch_size=1 algo.per_rank_sequence_length=1 " - f"algo.train_every=1 root_dir={root_dir} run_name={run_name} " + f"root_dir={root_dir} run_name={run_name} " "checkpoint.save_last=True metric.log_level=0 metric.disable_timer=True", shell=True, check=True,