From 48d20d3ace28cd01207b66377bd0f87fa3c7b39a Mon Sep 17 00:00:00 2001
From: belerico
Date: Tue, 10 Oct 2023 11:54:04 +0200
Subject: [PATCH 1/3] Fix metrics when they are nan

---
 sheeprl/utils/metric.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/sheeprl/utils/metric.py b/sheeprl/utils/metric.py
index c1b0b6f2..28fc77cc 100644
--- a/sheeprl/utils/metric.py
+++ b/sheeprl/utils/metric.py
@@ -1,4 +1,5 @@
 import warnings
+from math import isnan
 from typing import Any, Dict, List, Optional, Union
 
 import torch
@@ -105,6 +106,12 @@ def compute(self) -> Dict[str, List]:
                         category=RuntimeWarning,
                     )
                 reduced_metrics[k] = reduced
+
+                is_tensor = torch.is_tensor(reduced_metrics[k])
+                if (is_tensor and torch.isnan(reduced_metrics[k]).any()) or (
+                    not is_tensor and isnan(reduced_metrics[k])
+                ):
+                    reduced_metrics.pop(k, None)
         return reduced_metrics

From 82b576a13fcd0e8d96d1981d0d96e75ee474ce25 Mon Sep 17 00:00:00 2001
From: belerico
Date: Tue, 10 Oct 2023 11:57:58 +0200
Subject: [PATCH 2/3] Fix reshape when bootstrapping + fix normalization

---
 sheeprl/algos/ppo/ppo.py                     | 26 ++++++++-------
 sheeprl/algos/ppo/ppo_decoupled.py           |  8 ++---
 sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 34 +++++++++++---------
 3 files changed, 37 insertions(+), 31 deletions(-)

diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py
index a903ecf0..4b472874 100644
--- a/sheeprl/algos/ppo/ppo.py
+++ b/sheeprl/algos/ppo/ppo.py
@@ -313,7 +313,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                 for k, v in info["final_observation"][truncated_env].items():
                     torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                     if k in cfg.cnn_keys.encoder:
-                        torch_v = torch_v / 255.0
+                        torch_v = torch_v.view(len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
                     real_next_obs[k][i] = torch_v
             with torch.no_grad():
                 vals = agent.module.get_value(real_next_obs).cpu().numpy()
@@ -416,17 +416,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
             timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters
diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py
index ce653cb9..96b4fdce 100644
--- a/sheeprl/algos/ppo/ppo_decoupled.py
+++ b/sheeprl/algos/ppo/ppo_decoupled.py
@@ -199,7 +199,7 @@ def player(
         with torch.no_grad():
             # Sample an action given the observation received by the environment
             normalized_obs = {
-                k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys
+                k: next_obs[k] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys
             }
             actions, logprobs, _, values = agent(normalized_obs)
             if is_continuous:
@@ -225,7 +225,7 @@ def player(
                 for k, v in info["final_observation"][truncated_env].items():
                     torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                     if k in cfg.cnn_keys.encoder:
-                        torch_v = torch_v / 255.0
+                        torch_v = torch_v.view(len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
                     real_next_obs[k][i] = torch_v
             with torch.no_grad():
                 vals = agent.module.get_value(real_next_obs).cpu().numpy()
@@ -268,7 +268,7 @@ def player(
                     fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")
 
         # Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
-        normalized_obs = {k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys}
+        normalized_obs = {k: next_obs[k] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys}
         next_values = agent.get_value(normalized_obs)
         returns, advantages = gae(
             rb["rewards"],
@@ -474,7 +474,7 @@ def trainer(
         for batch_idxes in sampler:
             batch = data[batch_idxes]
             normalized_obs = {
-                k: batch[k] / 255 - 0.5 if k in agent.feature_extractor.cnn_keys else batch[k]
+                k: batch[k] / 255.0 - 0.5 if k in agent.feature_extractor.cnn_keys else batch[k]
                 for k in cfg.cnn_keys.encoder + cfg.mlp_keys.encoder
             }
             _, logprobs, entropy, new_values = agent(
diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
index af5e057e..e507cfaa 100644
--- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
+++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -55,7 +55,7 @@ def train(
             batch = data[:, idxes]
             mask = batch["mask"].unsqueeze(-1)
             for k in cfg.cnn_keys.encoder:
-                batch[k] = batch[k] / 255.0
+                batch[k] = batch[k] / 255.0 - 0.5
 
             _, logprobs, entropies, values, _ = agent(
                 {k: batch[k] for k in set(cfg.cnn_keys.encoder + cfg.mlp_keys.encoder)},
@@ -297,7 +297,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             with torch.no_grad():
                 # Sample an action given the observation received by the environment
                 normalized_obs = {
-                    k: obs[k][None] / 255.0 if k in cfg.cnn_keys.encoder else obs[k][None] for k in obs_keys
+                    k: obs[k][None] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else obs[k][None] for k in obs_keys
                 }  # [Seq_len, Batch_size, D] --> [1, num_envs, D]
                 actions, logprobs, _, values, states = agent.module(
                     normalized_obs, prev_actions=prev_actions, prev_states=prev_states
@@ -326,7 +326,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                     for k, v in info["final_observation"][truncated_env].items():
                         torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
                         if k in cfg.cnn_keys.encoder:
-                            torch_v = torch_v / 255.0
+                            torch_v = torch_v.view(1, len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
                         real_next_obs[k][0, i] = torch_v
                 with torch.no_grad():
                     feat = agent.module.feature_extractor(real_next_obs)
@@ -386,7 +386,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
         # Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
         with torch.no_grad():
-            normalized_obs = {k: obs[k][None] / 255.0 if k in cfg.cnn_keys.encoder else obs[k][None] for k in obs_keys}
+            normalized_obs = {
+                k: obs[k][None] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else obs[k][None] for k in obs_keys
+            }
             feat = agent.module.feature_extractor(normalized_obs)
             rnn_out, _ = agent.module.rnn(torch.cat((feat, actions), dim=-1), states)
             next_values = agent.module.get_values(rnn_out)
@@ -462,17 +464,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
             timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters

From f029e1e8dd9afb86c44f700c9f55d66cd7aa75f5 Mon Sep 17 00:00:00 2001
From: belerico
Date: Tue, 10 Oct 2023 11:58:29 +0200
Subject: [PATCH 3/3] Guard timer metrics

---
 howto/register_new_algorithm.md              | 24 +++++++++++++-----------
 sheeprl/algos/dreamer_v1/dreamer_v1.py       | 24 +++++++++++++-----------
 sheeprl/algos/dreamer_v2/dreamer_v2.py       | 24 +++++++++++++-----------
 sheeprl/algos/dreamer_v3/dreamer_v3.py       | 26 +++++++++++++++-----------
 sheeprl/algos/droq/droq.py                   | 24 +++++++++++++-----------
 sheeprl/algos/p2e_dv1/p2e_dv1.py             | 25 +++++++++++++------------
 sheeprl/algos/p2e_dv2/p2e_dv2.py             | 24 +++++++++++++-----------
 sheeprl/algos/sac/sac.py                     | 24 +++++++++++++-----------
 sheeprl/algos/sac_ae/sac_ae.py               | 24 +++++++++++++-----------
 9 files changed, 119 insertions(+), 100 deletions(-)

diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md
index 3cac48ed..ddd39a63 100644
--- a/howto/register_new_algorithm.md
+++ b/howto/register_new_algorithm.md
@@ -244,17 +244,19 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
             timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters
diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py
index 8948d1ab..f3e040aa 100644
--- a/sheeprl/algos/dreamer_v1/dreamer_v1.py
+++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -745,17 +745,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
             timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters
diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py
index a825163b..cd0105a0 100644
--- a/sheeprl/algos/dreamer_v2/dreamer_v2.py
+++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -828,17 +828,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
             timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters
diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py
index d6d0eb9c..47daf2be 100644
--- a/sheeprl/algos/dreamer_v3/dreamer_v3.py
+++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -479,6 +479,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         cfg.algo.actor.moments.percentile.low,
         cfg.algo.actor.moments.percentile.high,
     )
+    if cfg.checkpoint.resume_from:
+        moments.load_state_dict(state["moments"])
 
     # Metrics
     aggregator = MetricAggregator(
@@ -744,17 +746,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
            timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters
diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py
index 4a91889c..e7fc8ce5 100644
--- a/sheeprl/algos/droq/droq.py
+++ b/sheeprl/algos/droq/droq.py
@@ -366,17 +366,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
             timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters
diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py
index eb4f19bb..4fae178a 100644
--- a/sheeprl/algos/p2e_dv1/p2e_dv1.py
+++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py
@@ -741,7 +741,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
             step_data["dones"] = dones
             step_data["actions"] = actions
-            step_data["observations"] = obs
             step_data["rewards"] = clip_rewards_fn(rewards)
             rb.add(step_data[None, ...])
 
@@ -816,17 +815,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
             timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters
diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py
index c6aa39cb..1c955f1c 100644
--- a/sheeprl/algos/p2e_dv2/p2e_dv2.py
+++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py
@@ -1018,17 +1018,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
             timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters
diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py
index d727dfb0..3e98f391 100644
--- a/sheeprl/algos/sac/sac.py
+++ b/sheeprl/algos/sac/sac.py
@@ -356,17 +356,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
             timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters
diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py
index ec839248..48e6679c 100644
--- a/sheeprl/algos/sac_ae/sac_ae.py
+++ b/sheeprl/algos/sac_ae/sac_ae.py
@@ -519,17 +519,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             # Sync distributed timers
             timer_metrics = timer.compute()
-            fabric.log(
-                "Time/sps_train",
-                (train_step - last_train) / timer_metrics["Time/train_time"],
-                policy_step,
-            )
-            fabric.log(
-                "Time/sps_env_interaction",
-                ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-                / timer_metrics["Time/env_interaction_time"],
-                policy_step,
-            )
+            if "Time/train_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_train",
+                    (train_step - last_train) / timer_metrics["Time/train_time"],
+                    policy_step,
+                )
+            if "Time/env_interaction_time" in timer_metrics:
+                fabric.log(
+                    "Time/sps_env_interaction",
+                    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+                    / timer_metrics["Time/env_interaction_time"],
+                    policy_step,
+                )
             timer.reset()
 
             # Reset counters
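
Note: the snippet below is not part of the patches; it is a minimal, self-contained Python sketch of the NaN guard that PATCH 1/3 adds to MetricAggregator.compute, with made-up metric values: any key whose reduced value is NaN is dropped before the dictionary is returned, so it never reaches the logger.

import torch
from math import isnan

def drop_nan_metrics(reduced_metrics: dict) -> dict:
    # Same check as the one added in sheeprl/utils/metric.py: keep a metric only
    # if its reduced value (tensor or plain float) is not NaN.
    for k in list(reduced_metrics):
        v = reduced_metrics[k]
        is_tensor = torch.is_tensor(v)
        if (is_tensor and torch.isnan(v).any()) or (not is_tensor and isnan(v)):
            reduced_metrics.pop(k, None)
    return reduced_metrics

# "Rewards/rew_avg" is NaN here (e.g. no episode finished in the interval), so it is dropped.
print(drop_nan_metrics({"Loss/value_loss": torch.tensor(0.3), "Rewards/rew_avg": float("nan"), "Time/train_time": 1.2}))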
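
Note: a second sketch (also not part of the patches) of why PATCH 2/3 and 3/3 guard the timer keys: once NaN metrics are stripped, or when a timing was never recorded in a logging interval, timer.compute() can return a dict without "Time/train_time" or "Time/env_interaction_time", and the unguarded lookup would raise a KeyError. The counters and the log() stand-in below are hypothetical; in sheeprl the call is fabric.log.

def log(name: str, value: float, step: int) -> None:
    # Stand-in for fabric.log: just print the metric.
    print(f"step={step} {name}={value:.2f}")

# Hypothetical counters for one logging interval.
train_step, last_train, policy_step, last_log = 12, 10, 2048, 1024
world_size, action_repeat = 1, 1
timer_metrics = {"Time/env_interaction_time": 2.5}  # "Time/train_time" missing this interval

# Guarded logging, mirroring the patched code: only log a throughput whose timing exists.
if "Time/train_time" in timer_metrics:
    log("Time/sps_train", (train_step - last_train) / timer_metrics["Time/train_time"], policy_step)
if "Time/env_interaction_time" in timer_metrics:
    log(
        "Time/sps_env_interaction",
        ((policy_step - last_log) / world_size * action_repeat) / timer_metrics["Time/env_interaction_time"],
        policy_step,
    )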