Fix/algos #120
Merged: 3 commits, Oct 10, 2023
howto/register_new_algorithm.md: 13 additions, 11 deletions
@@ -244,17 +244,19 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):

# Sync distributed timers
timer_metrics = timer.compute()
-fabric.log(
-    "Time/sps_train",
-    (train_step - last_train) / timer_metrics["Time/train_time"],
-    policy_step,
-)
-fabric.log(
-    "Time/sps_env_interaction",
-    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-    / timer_metrics["Time/env_interaction_time"],
-    policy_step,
-)
+if "Time/train_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_train",
+        (train_step - last_train) / timer_metrics["Time/train_time"],
+        policy_step,
+    )
+if "Time/env_interaction_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_env_interaction",
+        ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+        / timer_metrics["Time/env_interaction_time"],
+        policy_step,
+    )
timer.reset()

# Reset counters
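Every algorithm in this PR receives the same fix: the two throughput logs are emitted only when the corresponding key is present in the synced timer metrics, since `timer.compute()` may have nothing to report for a timer that never ran in the current window (for example, a logging window without a train step). A minimal, self-contained sketch of the pattern; the `log_speed` helper is hypothetical and simply stands in for the `fabric.log` calls:

from typing import Dict


def log_speed(
    timer_metrics: Dict[str, float],
    train_step: int,
    last_train: int,
    policy_step: int,
    last_log: int,
    world_size: int,
    action_repeat: int,
    log=print,
) -> None:
    # Only log a throughput value when its timer actually accumulated time in
    # this window; otherwise the key is missing and the division would fail.
    if "Time/train_time" in timer_metrics:
        log("Time/sps_train", (train_step - last_train) / timer_metrics["Time/train_time"], policy_step)
    if "Time/env_interaction_time" in timer_metrics:
        log(
            "Time/sps_env_interaction",
            ((policy_step - last_log) / world_size * action_repeat)
            / timer_metrics["Time/env_interaction_time"],
            policy_step,
        )


# A window with environment interaction but no train step logs only one rate.
log_speed({"Time/env_interaction_time": 2.0}, train_step=10, last_train=10,
          policy_step=2048, last_log=1024, world_size=1, action_repeat=4)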
sheeprl/algos/dreamer_v1/dreamer_v1.py: 13 additions, 11 deletions
@@ -745,17 +745,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Sync distributed timers
timer_metrics = timer.compute()
-fabric.log(
-    "Time/sps_train",
-    (train_step - last_train) / timer_metrics["Time/train_time"],
-    policy_step,
-)
-fabric.log(
-    "Time/sps_env_interaction",
-    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-    / timer_metrics["Time/env_interaction_time"],
-    policy_step,
-)
+if "Time/train_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_train",
+        (train_step - last_train) / timer_metrics["Time/train_time"],
+        policy_step,
+    )
+if "Time/env_interaction_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_env_interaction",
+        ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+        / timer_metrics["Time/env_interaction_time"],
+        policy_step,
+    )
timer.reset()

# Reset counters
sheeprl/algos/dreamer_v2/dreamer_v2.py: 13 additions, 11 deletions
@@ -828,17 +828,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Sync distributed timers
timer_metrics = timer.compute()
-fabric.log(
-    "Time/sps_train",
-    (train_step - last_train) / timer_metrics["Time/train_time"],
-    policy_step,
-)
-fabric.log(
-    "Time/sps_env_interaction",
-    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-    / timer_metrics["Time/env_interaction_time"],
-    policy_step,
-)
+if "Time/train_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_train",
+        (train_step - last_train) / timer_metrics["Time/train_time"],
+        policy_step,
+    )
+if "Time/env_interaction_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_env_interaction",
+        ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+        / timer_metrics["Time/env_interaction_time"],
+        policy_step,
+    )
timer.reset()

# Reset counters
sheeprl/algos/dreamer_v3/dreamer_v3.py: 15 additions, 11 deletions
@@ -479,6 +479,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
    cfg.algo.actor.moments.percentile.low,
    cfg.algo.actor.moments.percentile.high,
)
+if cfg.checkpoint.resume_from:
+    moments.load_state_dict(state["moments"])

# Metrics
aggregator = MetricAggregator(
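This is the only non-logging change in DreamerV3: when `cfg.checkpoint.resume_from` is set, the return-normalization moments are restored from the checkpointed `state["moments"]` instead of restarting from their defaults. A self-contained sketch of the idea, with a toy `Moments` class (hypothetical, standing in for sheeprl's percentile tracker):

import torch


class Moments:
    """Toy return-normalization tracker with a state_dict round trip."""

    def __init__(self, low: float = 0.0, high: float = 1.0) -> None:
        self.low, self.high = low, high

    def state_dict(self) -> dict:
        return {"low": self.low, "high": self.high}

    def load_state_dict(self, state: dict) -> None:
        self.low, self.high = state["low"], state["high"]


# Save the moments alongside the rest of the training state...
moments = Moments(low=-3.2, high=7.5)
torch.save({"moments": moments.state_dict()}, "ckpt.pt")

# ...and, when resuming, restore them as the patch does with state["moments"].
resumed = Moments()
resumed.load_state_dict(torch.load("ckpt.pt")["moments"])
assert resumed.state_dict() == {"low": -3.2, "high": 7.5}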
@@ -744,17 +746,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Sync distributed timers
timer_metrics = timer.compute()
-fabric.log(
-    "Time/sps_train",
-    (train_step - last_train) / timer_metrics["Time/train_time"],
-    policy_step,
-)
-fabric.log(
-    "Time/sps_env_interaction",
-    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-    / timer_metrics["Time/env_interaction_time"],
-    policy_step,
-)
+if "Time/train_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_train",
+        (train_step - last_train) / timer_metrics["Time/train_time"],
+        policy_step,
+    )
+if "Time/env_interaction_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_env_interaction",
+        ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+        / timer_metrics["Time/env_interaction_time"],
+        policy_step,
+    )
timer.reset()

# Reset counters
sheeprl/algos/droq/droq.py: 13 additions, 11 deletions
@@ -366,17 +366,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Sync distributed timers
timer_metrics = timer.compute()
-fabric.log(
-    "Time/sps_train",
-    (train_step - last_train) / timer_metrics["Time/train_time"],
-    policy_step,
-)
-fabric.log(
-    "Time/sps_env_interaction",
-    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-    / timer_metrics["Time/env_interaction_time"],
-    policy_step,
-)
+if "Time/train_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_train",
+        (train_step - last_train) / timer_metrics["Time/train_time"],
+        policy_step,
+    )
+if "Time/env_interaction_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_env_interaction",
+        ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+        / timer_metrics["Time/env_interaction_time"],
+        policy_step,
+    )
timer.reset()

# Reset counters
sheeprl/algos/p2e_dv1/p2e_dv1.py: 13 additions, 12 deletions
@@ -741,7 +741,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

step_data["dones"] = dones
step_data["actions"] = actions
step_data["observations"] = obs
step_data["rewards"] = clip_rewards_fn(rewards)
rb.add(step_data[None, ...])

@@ -816,17 +815,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Sync distributed timers
timer_metrics = timer.compute()
-fabric.log(
-    "Time/sps_train",
-    (train_step - last_train) / timer_metrics["Time/train_time"],
-    policy_step,
-)
-fabric.log(
-    "Time/sps_env_interaction",
-    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-    / timer_metrics["Time/env_interaction_time"],
-    policy_step,
-)
+if "Time/train_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_train",
+        (train_step - last_train) / timer_metrics["Time/train_time"],
+        policy_step,
+    )
+if "Time/env_interaction_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_env_interaction",
+        ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+        / timer_metrics["Time/env_interaction_time"],
+        policy_step,
+    )
timer.reset()

# Reset counters
sheeprl/algos/p2e_dv2/p2e_dv2.py: 13 additions, 11 deletions
@@ -1018,17 +1018,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Sync distributed timers
timer_metrics = timer.compute()
-fabric.log(
-    "Time/sps_train",
-    (train_step - last_train) / timer_metrics["Time/train_time"],
-    policy_step,
-)
-fabric.log(
-    "Time/sps_env_interaction",
-    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-    / timer_metrics["Time/env_interaction_time"],
-    policy_step,
-)
+if "Time/train_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_train",
+        (train_step - last_train) / timer_metrics["Time/train_time"],
+        policy_step,
+    )
+if "Time/env_interaction_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_env_interaction",
+        ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+        / timer_metrics["Time/env_interaction_time"],
+        policy_step,
+    )
timer.reset()

# Reset counters
sheeprl/algos/ppo/ppo.py: 14 additions, 12 deletions
@@ -313,7 +313,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
for k, v in info["final_observation"][truncated_env].items():
    torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
    if k in cfg.cnn_keys.encoder:
-        torch_v = torch_v / 255.0
+        torch_v = torch_v.view(len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
    real_next_obs[k][i] = torch_v
with torch.no_grad():
    vals = agent.module.get_value(real_next_obs).cpu().numpy()
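The change above touches how the final observation of a truncated episode is prepared before bootstrapping its value: the tensor is reshaped to the image layout of the regular observation batch (via `torch_obs.shape[-2:]`) and normalized with the same `/ 255.0 - 0.5` used for the observations fed to the agent, rather than only `/ 255.0`. A small sketch of that reshape-and-normalize step, under assumed shapes (two truncated envs, 3x64x64 frames):

import torch

# Assumed shapes for illustration only.
truncated_envs = [0, 3]
flat_final_obs = torch.randint(0, 256, (len(truncated_envs), 3 * 64 * 64), dtype=torch.uint8)
torch_obs = torch.zeros(4, 3, 64, 64)  # regular observation batch, supplies H and W

torch_v = flat_final_obs.to(torch.float32)
# Rebuild (num_truncated, C, H, W) and map pixels from [0, 255] to [-0.5, 0.5].
torch_v = torch_v.view(len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5

print(torch_v.shape)                               # torch.Size([2, 3, 64, 64])
print(torch_v.min().item(), torch_v.max().item())  # within [-0.5, 0.5]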
@@ -416,17 +416,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Sync distributed timers
timer_metrics = timer.compute()
-fabric.log(
-    "Time/sps_train",
-    (train_step - last_train) / timer_metrics["Time/train_time"],
-    policy_step,
-)
-fabric.log(
-    "Time/sps_env_interaction",
-    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-    / timer_metrics["Time/env_interaction_time"],
-    policy_step,
-)
+if "Time/train_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_train",
+        (train_step - last_train) / timer_metrics["Time/train_time"],
+        policy_step,
+    )
+if "Time/env_interaction_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_env_interaction",
+        ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+        / timer_metrics["Time/env_interaction_time"],
+        policy_step,
+    )
timer.reset()

# Reset counters
sheeprl/algos/ppo/ppo_decoupled.py: 4 additions, 4 deletions
@@ -199,7 +199,7 @@ def player(
with torch.no_grad():
    # Sample an action given the observation received by the environment
    normalized_obs = {
-        k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys
+        k: next_obs[k] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys
    }
    actions, logprobs, _, values = agent(normalized_obs)
    if is_continuous:
@@ -225,7 +225,7 @@ def player(
for k, v in info["final_observation"][truncated_env].items():
    torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
    if k in cfg.cnn_keys.encoder:
-        torch_v = torch_v / 255.0
+        torch_v = torch_v.view(len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
    real_next_obs[k][i] = torch_v
with torch.no_grad():
    vals = agent.module.get_value(real_next_obs).cpu().numpy()
@@ -268,7 +268,7 @@ def player(
fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
-normalized_obs = {k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys}
+normalized_obs = {k: next_obs[k] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys}
next_values = agent.get_value(normalized_obs)
returns, advantages = gae(
    rb["rewards"],
@@ -474,7 +474,7 @@ def trainer(
for batch_idxes in sampler:
    batch = data[batch_idxes]
    normalized_obs = {
-        k: batch[k] / 255 - 0.5 if k in agent.feature_extractor.cnn_keys else batch[k]
+        k: batch[k] / 255.0 - 0.5 if k in agent.feature_extractor.cnn_keys else batch[k]
        for k in cfg.cnn_keys.encoder + cfg.mlp_keys.encoder
    }
    _, logprobs, entropy, new_values = agent(
sheeprl/algos/ppo_recurrent/ppo_recurrent.py: 19 additions, 15 deletions
@@ -55,7 +55,7 @@ def train(
batch = data[:, idxes]
mask = batch["mask"].unsqueeze(-1)
for k in cfg.cnn_keys.encoder:
-    batch[k] = batch[k] / 255.0
+    batch[k] = batch[k] / 255.0 - 0.5

_, logprobs, entropies, values, _ = agent(
    {k: batch[k] for k in set(cfg.cnn_keys.encoder + cfg.mlp_keys.encoder)},
@@ -297,7 +297,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
with torch.no_grad():
    # Sample an action given the observation received by the environment
    normalized_obs = {
-        k: obs[k][None] / 255.0 if k in cfg.cnn_keys.encoder else obs[k][None] for k in obs_keys
+        k: obs[k][None] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else obs[k][None] for k in obs_keys
    }  # [Seq_len, Batch_size, D] --> [1, num_envs, D]
    actions, logprobs, _, values, states = agent.module(
        normalized_obs, prev_actions=prev_actions, prev_states=prev_states
@@ -326,7 +326,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
for k, v in info["final_observation"][truncated_env].items():
    torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
    if k in cfg.cnn_keys.encoder:
-        torch_v = torch_v / 255.0
+        torch_v = torch_v.view(1, len(truncated_envs), -1, *torch_obs.shape[-2:]) / 255.0 - 0.5
    real_next_obs[k][0, i] = torch_v
with torch.no_grad():
    feat = agent.module.feature_extractor(real_next_obs)
@@ -386,7 +386,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
with torch.no_grad():
-    normalized_obs = {k: obs[k][None] / 255.0 if k in cfg.cnn_keys.encoder else obs[k][None] for k in obs_keys}
+    normalized_obs = {
+        k: obs[k][None] / 255.0 - 0.5 if k in cfg.cnn_keys.encoder else obs[k][None] for k in obs_keys
+    }
    feat = agent.module.feature_extractor(normalized_obs)
    rnn_out, _ = agent.module.rnn(torch.cat((feat, actions), dim=-1), states)
    next_values = agent.module.get_values(rnn_out)
@@ -462,17 +464,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Sync distributed timers
timer_metrics = timer.compute()
-fabric.log(
-    "Time/sps_train",
-    (train_step - last_train) / timer_metrics["Time/train_time"],
-    policy_step,
-)
-fabric.log(
-    "Time/sps_env_interaction",
-    ((policy_step - last_log) / world_size * cfg.env.action_repeat)
-    / timer_metrics["Time/env_interaction_time"],
-    policy_step,
-)
+if "Time/train_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_train",
+        (train_step - last_train) / timer_metrics["Time/train_time"],
+        policy_step,
+    )
+if "Time/env_interaction_time" in timer_metrics:
+    fabric.log(
+        "Time/sps_env_interaction",
+        ((policy_step - last_log) / world_size * cfg.env.action_repeat)
+        / timer_metrics["Time/env_interaction_time"],
+        policy_step,
+    )
timer.reset()

# Reset counters