From 1b2dd2a6116f20684afd486425340defdbcbfd67 Mon Sep 17 00:00:00 2001
From: belerico
Date: Wed, 27 Mar 2024 18:56:37 +0100
Subject: [PATCH] Wrap actor with single-device fabric

---
 sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 4 +++-
 sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py  | 8 ++++++--
 sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 4 +++-
 sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py  | 8 ++++++--
 sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 4 +++-
 sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py  | 9 +++++++--
 6 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
index 99086504..9ad04faa 100644
--- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
+++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -24,6 +24,7 @@
 from sheeprl.algos.p2e_dv1.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -835,8 +836,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     envs.close()
     # task test zero-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        fabric_player = get_single_device_fabric(fabric)
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "zero-shot")

     if not cfg.model_manager.disabled and fabric.is_global_zero:
diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
index 63947b7d..6da54a91 100644
--- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
+++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -19,6 +19,7 @@
 from sheeprl.algos.p2e_dv1.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -31,6 +32,9 @@

 @register_algorithm()
 def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
+    # Single-device fabric object
+    fabric_player = get_single_device_fabric(fabric)
+
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
@@ -345,8 +349,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
         # Train the agent
         if update >= learning_starts and updates_before_training <= 0:
             if player.actor_type == "exploration":
-                player.actor = actor_task
                 player.actor_type = "task"
+                player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
             with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
                 for i in range(cfg.algo.per_rank_gradient_steps):
                     sample = rb.sample_tensors(
@@ -451,8 +455,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     envs.close()
     # task test few-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "few-shot")

     if not cfg.model_manager.disabled and fabric.is_global_zero:
diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
index 11c13fe2..54fad041 100644
--- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
+++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -24,6 +24,7 @@
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer
 from sheeprl.utils.distribution import OneHotCategoricalValidateArgs
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -1000,8 +1001,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     envs.close()
     # task test zero-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        fabric_player = get_single_device_fabric(fabric)
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "zero-shot")

     if not cfg.model_manager.disabled and fabric.is_global_zero:
diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
index b0a18af7..d00f8c69 100644
--- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
+++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -19,6 +19,7 @@
 from sheeprl.algos.p2e_dv2.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -31,6 +32,9 @@

 @register_algorithm()
 def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
+    # Single-device fabric object
+    fabric_player = get_single_device_fabric(fabric)
+
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
@@ -368,8 +372,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
         # Train the agent
         if update >= learning_starts and updates_before_training <= 0:
             if player.actor_type == "exploration":
-                player.actor = actor_task
                 player.actor_type = "task"
+                player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
             n_samples = (
                 cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps
             )
@@ -484,8 +488,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     envs.close()
     # task test few-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "few-shot")

     if not cfg.model_manager.disabled and fabric.is_global_zero:
diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
index a33397dc..106cb655 100644
--- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
+++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -28,6 +28,7 @@
     TwoHotEncodingDistribution,
 )
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -1082,8 +1083,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     envs.close()
     # task test zero-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        fabric_player = get_single_device_fabric(fabric)
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "zero-shot")

     if not cfg.model_manager.disabled and fabric.is_global_zero:
diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
index aae93178..b5b482f8 100644
--- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
+++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -17,6 +17,7 @@
 from sheeprl.algos.p2e_dv3.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -26,6 +27,9 @@

 @register_algorithm()
 def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
+    # Single-device fabric object
+    fabric_player = get_single_device_fabric(fabric)
+
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
@@ -370,8 +374,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
         # Train the agent
         if update >= learning_starts and updates_before_training <= 0:
             if player.actor_type == "exploration":
-                player.actor = actor_task
                 player.actor_type = "task"
+                player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
             local_data = rb.sample_tensors(
                 cfg.algo.per_rank_batch_size,
                 sequence_length=cfg.algo.per_rank_sequence_length,
@@ -487,10 +491,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     )

     envs.close()
+
     # task test few-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "few-shot")

     if not cfg.model_manager.disabled and fabric.is_global_zero:
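Two notes on the pattern this patch repeats across all six files. First, get_single_device_fabric (imported from sheeprl/utils/fabric.py) builds a Fabric bound to a single device so the player can run inference outside the distributed strategy. Its body is not part of this patch; what follows is only a minimal sketch of the idea, assuming Lightning's SingleDeviceStrategy and reuse of the parent fabric's device, accelerator, and precision plugin, which may differ from the actual helper:

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import SingleDeviceStrategy

    def get_single_device_fabric(fabric: Fabric) -> Fabric:
        # Keep the parent's device, accelerator, and precision plugin, but
        # swap the (possibly distributed) strategy for a single-device one.
        strategy = SingleDeviceStrategy(
            device=fabric.device,
            accelerator=fabric.strategy.accelerator,
            precision=fabric.strategy.precision,
        )
        return Fabric(strategy=strategy)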
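Second, the getattr(actor_task, "module", actor_task) idiom: fabric.setup_module returns a _FabricModule wrapper that exposes the original nn.Module at .module. Unwrapping first lets the single-device fabric wrap the bare module instead of nesting one wrapper inside another, while an actor that was never wrapped falls through unchanged. A self-contained illustration (nn.Linear stands in for the real actor network):

    import torch.nn as nn
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=1)
    fabric.launch()

    actor = nn.Linear(4, 2)               # stand-in for the real actor
    wrapped = fabric.setup_module(actor)   # _FabricModule wrapper

    # Wrapped modules expose the original at `.module`; bare modules don't,
    # so getattr falls back to the module itself in both cases.
    assert getattr(wrapped, "module", wrapped) is actor
    assert getattr(actor, "module", actor) is actor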