Feature/optimizations (#177)
* Do not create metrics when timer is disabled + do not cat when unneeded

* Add run test optionally

* Fix wrong import
belerico authored Dec 22, 2023
1 parent 6c2e0b4 commit 955a57f
Showing 23 changed files with 73 additions and 64 deletions.
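
The timer change repeated across these files is mechanical: instead of constructing a SumMetric(sync_on_compute=...) instance at every call site, the call sites now pass the metric class and its keyword arguments to timer, which can skip building the metric entirely when timing is disabled. The sketch below only illustrates the idea behind the new signature; it is not SheepRL's actual timer implementation, and the TIMERS registry and TIMER_DISABLED flag are hypothetical stand-ins for the library's internal state.

import time
from contextlib import contextmanager

from torchmetrics import SumMetric

# Hypothetical module-level state standing in for the library's timer registry.
TIMERS: dict = {}
TIMER_DISABLED = False


@contextmanager
def timer(name: str, metric_cls=None, **metric_kwargs):
    # When timing is disabled (or no metric class is given), the metric is never
    # instantiated; this is the optimization the commit applies at each call site.
    if TIMER_DISABLED or metric_cls is None:
        yield
        return
    # Build the metric lazily, only the first time this name is timed.
    if name not in TIMERS:
        TIMERS[name] = metric_cls(**metric_kwargs)
    start = time.perf_counter()
    try:
        yield
    finally:
        TIMERS[name].update(time.perf_counter() - start)


# Old call sites constructed the metric unconditionally, even with timing off:
#     with timer("Time/train_time", SumMetric(sync_on_compute=False)):
#         ...
# New call sites defer construction to the timer:
with timer("Time/train_time", SumMetric, sync_on_compute=False):
    time.sleep(0.01)
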
4 changes: 2 additions & 2 deletions howto/register_external_algorithm.md
@@ -268,7 +268,7 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Sample an action given the observation received by the environment
normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys)
@@ -370,7 +370,7 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(agent.module, fabric, cfg, log_dir)

# Optional part in case you want to give the possibility to register your models with MLFlow
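
The other change in this diff gates the final evaluation behind a new algo.run_test flag, so rank zero only calls test(...) when the flag is enabled. Assuming the flag is exposed through the algorithm's Hydra/omegaconf config (its default value is not visible in the files shown here), the guard behaves like this minimal Python sketch, where the config fragment and the maybe_test helper are illustrative only:

from omegaconf import OmegaConf

# Hypothetical fragment of an algorithm config carrying the new flag.
cfg = OmegaConf.create({"algo": {"run_test": True}})


def maybe_test(is_global_zero: bool) -> None:
    # The final test now runs only on rank zero and only when run_test is enabled;
    # overriding algo.run_test=False in the config skips it.
    if is_global_zero and cfg.algo.run_test:
        print("running final test()")
    else:
        print("skipping final test()")


maybe_test(is_global_zero=True)  # prints "running final test()"
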
4 changes: 2 additions & 2 deletions howto/register_new_algorithm.md
@@ -265,7 +265,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Sample an action given the observation received by the environment
normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys)
@@ -367,7 +367,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(agent.module, fabric, cfg, log_dir)

# Optional part in case you want to give the possibility to register your models with MLFlow
6 changes: 3 additions & 3 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -589,7 +589,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -681,7 +681,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Train the agent
if update > learning_starts and updates_before_training <= 0:
# Start training
- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(cfg.algo.per_rank_gradient_steps):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
@@ -775,7 +775,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(player, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
6 changes: 3 additions & 3 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -631,7 +631,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -735,7 +735,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()):
@@ -828,7 +828,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(player, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
6 changes: 3 additions & 3 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -566,7 +566,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -678,7 +678,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau
@@ -775,7 +775,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(player, fabric, cfg, log_dir, sample_actions=True)

if not cfg.model_manager.disabled and fabric.is_global_zero:
6 changes: 3 additions & 3 deletions sheeprl/algos/droq/droq.py
@@ -79,7 +79,7 @@ def train(
)
actor_data = {k: actor_data[k][next(iter(actor_sampler))] for k in actor_data.keys()}

- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
# Update the soft-critic
for batch_idxes in critic_sampler:
critic_batch_data = {k: critic_data[k][batch_idxes] for k in critic_data.keys()}
@@ -283,7 +283,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Sample an action given the observation received by the environment
actions, _ = agent.actor.module(torch.from_numpy(obs).to(device))
@@ -385,7 +385,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(agent.actor.module, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -629,7 +629,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -721,7 +721,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Train the agent
if update >= learning_starts and updates_before_training <= 0:
# Start training
- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(cfg.algo.per_rank_gradient_steps):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
@@ -835,7 +835,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

envs.close()
# task test zero-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "zero-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -271,7 +271,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
normalized_obs = {}
for k in obs_keys:
@@ -349,7 +349,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
if player.actor_type == "exploration":
player.actor = actor_task.module
player.actor_type = "task"
- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(cfg.algo.per_rank_gradient_steps):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
@@ -452,7 +452,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

envs.close()
# task test few-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "few-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -776,7 +776,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -881,7 +881,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
from_numpy=cfg.buffer.from_numpy,
)
# Start training
- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()):
@@ -1000,7 +1000,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

envs.close()
# task test zero-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "zero-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -291,7 +291,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
normalized_obs = {}
for k in obs_keys:
@@ -383,7 +383,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
from_numpy=cfg.buffer.from_numpy,
)
# Start training
- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()):
@@ -484,7 +484,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

envs.close()
# task test few-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "few-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -837,7 +837,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -950,7 +950,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
from_numpy=cfg.buffer.from_numpy,
)
# Start training
- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau
@@ -1079,7 +1079,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

envs.close()
# task test zero-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "zero-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -282,7 +282,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
preprocessed_obs = {}
for k, v in obs.items():
@@ -382,7 +382,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
from_numpy=cfg.buffer.from_numpy,
)
# Start training
- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
@@ -487,7 +487,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

envs.close()
# task test few-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "few-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/ppo/ppo.py
@@ -269,7 +269,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Sample an action given the observation received by the environment
normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys)
@@ -372,7 +372,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Flatten the first two dimensions: [Buffer_Size, Num_Envs]
gathered_data = {k: v.flatten(start_dim=0, end_dim=1).float() for k, v in local_data.items()}

- with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
+ with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
train(fabric, agent, optimizer, gathered_data, aggregator, cfg)
train_step += world_size

@@ -445,7 +445,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(agent.module, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo_decoupled.py
@@ -191,7 +191,7 @@ def player(

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
- with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
+ with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Sample an action given the observation received by the environment
normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys)
@@ -349,7 +349,7 @@ def player(
)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(agent, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero: