Commit 1b2dd2a: Wrap actor with single-device fabric

belerico committed Mar 27, 2024
1 parent 2c7286f
Showing 6 changed files with 28 additions and 9 deletions.
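The change is the same across all six P2E scripts: instead of pointing the player at the training-time actor_task (which may be wrapped by a distributed strategy), the actor is re-registered through a single-device Fabric before being used for testing or acting. The body of get_single_device_fabric is not part of this diff; what follows is a minimal sketch of what sheeprl/utils/fabric.get_single_device_fabric could look like, assuming it simply builds a Fabric around Lightning's SingleDeviceStrategy:

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import SingleDeviceStrategy

    def get_single_device_fabric(fabric: Fabric) -> Fabric:
        # Pin a fresh Fabric to this process's own device: modules passed to
        # its setup_module() are moved to that device but are NOT wrapped in
        # a distributed strategy (DDP/FSDP), so they can run in a single
        # process, e.g. on global rank zero during testing. The real helper
        # may also carry over the precision plugin; this sketch omits that.
        return Fabric(strategy=SingleDeviceStrategy(device=fabric.device))

A plausible motivation, not stated in the commit message: test() runs only when fabric.is_global_zero, and calling a distributed-wrapped actor from a single rank can hang on collective ops, while a single-device wrapper keeps device handling consistent without any distributed machinery.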
sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py (3 additions, 1 deletion)

@@ -24,6 +24,7 @@
 from sheeprl.algos.p2e_dv1.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -835,8 +836,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     envs.close()
     # task test zero-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        fabric_player = get_single_device_fabric(fabric)
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "zero-shot")
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
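This same swap recurs in every hunk below, so it is worth unpacking once. actor_task may already have been set up through the distributed fabric, in which case it is a _FabricModule exposing the underlying nn.Module as .module; the getattr falls back to the object itself when it is still a plain module. Illustratively (names as in the diff):

    # Unwrap a possibly Fabric-wrapped module, then re-wrap it on one device.
    raw_actor = getattr(actor_task, "module", actor_task)
    player.actor = fabric_player.setup_module(raw_actor)
    player.actor_type = "task"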
sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py (6 additions, 2 deletions)

@@ -19,6 +19,7 @@
 from sheeprl.algos.p2e_dv1.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -31,6 +32,9 @@
 
 @register_algorithm()
 def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
+    # Single-device fabric object
+    fabric_player = get_single_device_fabric(fabric)
+
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
@@ -345,8 +349,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
         # Train the agent
         if update >= learning_starts and updates_before_training <= 0:
             if player.actor_type == "exploration":
-                player.actor = actor_task
                 player.actor_type = "task"
+                player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
             with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
                 for i in range(cfg.algo.per_rank_gradient_steps):
                     sample = rb.sample_tensors(
@@ -451,8 +455,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     envs.close()
     # task test few-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "few-shot")
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
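Note the structural difference from the exploration script above: in the finetuning scripts fabric_player is created once at the top of main, because the actor swap also happens inside the training loop, not only in the final test branch.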
sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py (3 additions, 1 deletion)

@@ -24,6 +24,7 @@
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer
 from sheeprl.utils.distribution import OneHotCategoricalValidateArgs
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -1000,8 +1001,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     envs.close()
     # task test zero-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        fabric_player = get_single_device_fabric(fabric)
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "zero-shot")
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py (6 additions, 2 deletions)

@@ -19,6 +19,7 @@
 from sheeprl.algos.p2e_dv2.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -31,6 +32,9 @@
 
 @register_algorithm()
 def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
+    # Single-device fabric object
+    fabric_player = get_single_device_fabric(fabric)
+
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
@@ -368,8 +372,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
         # Train the agent
         if update >= learning_starts and updates_before_training <= 0:
             if player.actor_type == "exploration":
-                player.actor = actor_task
                 player.actor_type = "task"
+                player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
             n_samples = (
                 cfg.algo.per_rank_pretrain_steps if update == learning_starts else cfg.algo.per_rank_gradient_steps
             )
@@ -484,8 +488,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     envs.close()
     # task test few-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "few-shot")
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py (3 additions, 1 deletion)

@@ -28,6 +28,7 @@
     TwoHotEncodingDistribution,
 )
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -1082,8 +1083,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     envs.close()
     # task test zero-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        fabric_player = get_single_device_fabric(fabric)
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "zero-shot")
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py (7 additions, 2 deletions)

@@ -17,6 +17,7 @@
 from sheeprl.algos.p2e_dv3.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
+from sheeprl.utils.fabric import get_single_device_fabric
 from sheeprl.utils.logger import get_log_dir, get_logger
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
@@ -26,6 +27,9 @@
 
 @register_algorithm()
 def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
+    # Single-device fabric object
+    fabric_player = get_single_device_fabric(fabric)
+
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
@@ -370,8 +374,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
         # Train the agent
        if update >= learning_starts and updates_before_training <= 0:
            if player.actor_type == "exploration":
-                player.actor = actor_task
                 player.actor_type = "task"
+                player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
             local_data = rb.sample_tensors(
                 cfg.algo.per_rank_batch_size,
                 sequence_length=cfg.algo.per_rank_sequence_length,
@@ -487,10 +491,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     )
 
     envs.close()
+
     # task test few-shot
     if fabric.is_global_zero and cfg.algo.run_test:
-        player.actor = actor_task
         player.actor_type = "task"
+        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
         test(player, fabric, cfg, log_dir, "few-shot")
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
