Feature/optimizations #177

Merged: 3 commits, Dec 22, 2023
4 changes: 2 additions & 2 deletions howto/register_external_algorithm.md
@@ -268,7 +268,7 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Sample an action given the observation received by the environment
normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys)
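
The hunk above shows the core optimization of this PR: `timer` now receives the metric class (`SumMetric`) together with its keyword arguments instead of an already constructed `SumMetric(sync_on_compute=False)`, so the context manager can build the metric lazily and skip the allocation entirely when timers are disabled. Below is a minimal sketch of a context manager with that call shape; it is an illustration under assumptions, not sheeprl's actual `timer` implementation, and the module-level `TIMERS_DISABLED` flag is a hypothetical stand-in for however the library turns its timers off.

```python
# Minimal sketch (not sheeprl's actual code) of a timer context manager that
# accepts a torchmetrics class plus constructor kwargs, matching the new call sites.
import time
from contextlib import contextmanager
from typing import Any, Dict, Type

from torchmetrics import Metric, SumMetric

_TIMERS: Dict[str, Metric] = {}
TIMERS_DISABLED = False  # hypothetical switch; sheeprl controls this through its own config


@contextmanager
def timer(name: str, metric_cls: Type[Metric] = SumMetric, **metric_kwargs: Any):
    """Accumulate the wall-clock time spent inside the block into a named metric."""
    if TIMERS_DISABLED:
        # With the class+kwargs signature, no metric object is ever built when timing is off.
        yield
        return
    if name not in _TIMERS:
        _TIMERS[name] = metric_cls(**metric_kwargs)
    start = time.perf_counter()
    try:
        yield
    finally:
        _TIMERS[name].update(time.perf_counter() - start)


# Usage mirroring the diff above:
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
    time.sleep(0.01)  # stand-in for the policy forward pass plus the environment step
```
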
@@ -370,7 +370,7 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(agent.module, fabric, cfg, log_dir)

# Optional part in case you want to give the possibility to register your models with MLFlow
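
The second recurring change gates the final evaluation behind `cfg.algo.run_test`: rank zero still owns the test phase, but it now runs only when the configuration asks for it. A small sketch of the guard follows; the `maybe_test` helper and the inline config are illustrative only (the real flag and its default live in sheeprl's Hydra configs), and from the command line the flag would presumably be toggled with an override such as `algo.run_test=False`.

```python
# Hedged sketch: run the final test phase only on rank zero and only when the
# (assumed) algo.run_test flag is enabled. `maybe_test` is a hypothetical helper.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"algo": {"run_test": True}})  # illustrative config, not sheeprl's defaults


def maybe_test(fabric, agent, cfg, log_dir, test_fn):
    # Evaluation is skipped on non-zero ranks and whenever the flag is disabled.
    if fabric.is_global_zero and cfg.algo.run_test:
        test_fn(agent, fabric, cfg, log_dir)
```
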
4 changes: 2 additions & 2 deletions howto/register_new_algorithm.md
@@ -265,7 +265,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Sample an action given the observation received by the environment
normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys)
@@ -367,7 +367,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(agent.module, fabric, cfg, log_dir)

# Optional part in case you want to give the possibility to register your models with MLFlow
6 changes: 3 additions & 3 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -589,7 +589,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -681,7 +681,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Train the agent
if update > learning_starts and updates_before_training <= 0:
# Start training
with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(cfg.algo.per_rank_gradient_steps):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
@@ -775,7 +775,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(player, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
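
One detail worth noting in the hunks above: the PR keeps the training timer configurable through `sync_on_compute=cfg.metric.sync_on_compute`, while the environment-interaction timer hard-codes `sync_on_compute=False`. In torchmetrics, `sync_on_compute=True` synchronizes the metric state across processes when `compute()` is called, whereas `False` reports the per-rank value. A standalone example of the flag (both settings behave the same in a single-process run):

```python
# Standalone torchmetrics example of the sync_on_compute flag used by the timers.
# In a distributed run, sync_on_compute=True reduces the summed time across ranks
# at compute() time; False keeps the value local to each rank.
import torch
from torchmetrics import SumMetric

train_time = SumMetric(sync_on_compute=True)   # cross-rank aggregate (when distributed)
env_time = SumMetric(sync_on_compute=False)    # per-rank value only

for step_seconds in (0.12, 0.08, 0.10):
    train_time.update(torch.tensor(step_seconds))
    env_time.update(torch.tensor(step_seconds))

print(train_time.compute(), env_time.compute())  # both tensor(0.3000) in a single process
```
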
6 changes: 3 additions & 3 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -631,7 +631,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -735,7 +735,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()):
@@ -828,7 +828,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(player, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
6 changes: 3 additions & 3 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -566,7 +566,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -678,7 +678,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
device=fabric.device,
from_numpy=cfg.buffer.from_numpy,
)
with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau
@@ -775,7 +775,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(player, fabric, cfg, log_dir, sample_actions=True)

if not cfg.model_manager.disabled and fabric.is_global_zero:
6 changes: 3 additions & 3 deletions sheeprl/algos/droq/droq.py
@@ -79,7 +79,7 @@ def train(
)
actor_data = {k: actor_data[k][next(iter(actor_sampler))] for k in actor_data.keys()}

with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
# Update the soft-critic
for batch_idxes in critic_sampler:
critic_batch_data = {k: critic_data[k][batch_idxes] for k in critic_data.keys()}
@@ -283,7 +283,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Sample an action given the observation received by the environment
actions, _ = agent.actor.module(torch.from_numpy(obs).to(device))
@@ -385,7 +385,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(agent.actor.module, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -629,7 +629,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -721,7 +721,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Train the agent
if update >= learning_starts and updates_before_training <= 0:
# Start training
with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(cfg.algo.per_rank_gradient_steps):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
@@ -835,7 +835,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

envs.close()
# task test zero-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "zero-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -271,7 +271,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
normalized_obs = {}
for k in obs_keys:
@@ -349,7 +349,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
if player.actor_type == "exploration":
player.actor = actor_task.module
player.actor_type = "task"
with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(cfg.algo.per_rank_gradient_steps):
sample = rb.sample_tensors(
batch_size=cfg.algo.per_rank_batch_size,
@@ -452,7 +452,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

envs.close()
# task test few-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "few-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -776,7 +776,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -881,7 +881,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
from_numpy=cfg.buffer.from_numpy,
)
# Start training
with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()):
@@ -1000,7 +1000,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

envs.close()
# task test zero-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "zero-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -291,7 +291,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
normalized_obs = {}
for k in obs_keys:
@@ -383,7 +383,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
from_numpy=cfg.buffer.from_numpy,
)
# Start training
with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()):
@@ -484,7 +484,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

envs.close()
# task test few-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "few-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -837,7 +837,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
@@ -950,7 +950,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
from_numpy=cfg.buffer.from_numpy,
)
# Start training
with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau
@@ -1079,7 +1079,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

envs.close()
# task test zero-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "zero-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -282,7 +282,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
preprocessed_obs = {}
for k, v in obs.items():
@@ -382,7 +382,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
from_numpy=cfg.buffer.from_numpy,
)
# Start training
with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
for i in range(next(iter(local_data.values())).shape[0]):
tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau
if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0:
@@ -487,7 +487,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):

envs.close()
# task test few-shot
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
player.actor = actor_task.module
player.actor_type = "task"
test(player, fabric, cfg, log_dir, "few-shot")
6 changes: 3 additions & 3 deletions sheeprl/algos/ppo/ppo.py
@@ -269,7 +269,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Sample an action given the observation received by the environment
normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys)
@@ -372,7 +372,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Flatten the first two dimensions: [Buffer_Size, Num_Envs]
gathered_data = {k: v.flatten(start_dim=0, end_dim=1).float() for k, v in local_data.items()}

with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)):
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
train(fabric, agent, optimizer, gathered_data, aggregator, cfg)
train_step += world_size

@@ -445,7 +445,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(agent.module, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero:
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo_decoupled.py
@@ -191,7 +191,7 @@ def player(

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)):
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
with torch.no_grad():
# Sample an action given the observation received by the environment
normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys)
@@ -349,7 +349,7 @@ def player(
)

envs.close()
- if fabric.is_global_zero:
+ if fabric.is_global_zero and cfg.algo.run_test:
test(agent, fabric, cfg, log_dir)

if not cfg.model_manager.disabled and fabric.is_global_zero: