fix(nyz): fix ppo parallel bug (#709)
PaParaZz1 committed Sep 15, 2023

1 parent 9299826 commit 6e93b4c
Showing 3 changed files with 45 additions and 2 deletions.
37 changes: 35 additions & 2 deletions ding/example/ppo.py
@@ -1,3 +1,21 @@
"""
# Example of PPO pipeline
Use the pipeline on a single process:
> python3 -u ding/example/ppo.py
Use the pipeline on multiple processes:
We suppose there are N processes (workers) = 1 learner + 1 evaluator + (N-2) collectors
## First Example: Execute on one machine with multiple processes.
Execute 4 processes with 1 learner + 1 evaluator + 2 collectors
Remember to connect them with the mesh topology so that they can exchange information with each other.
> ditask --package . --main ding.example.ppo.main --parallel-workers 4 --topology mesh
"""
import gym
from ditk import logging
from ding.model import VAC
@@ -8,14 +26,14 @@
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
gae_estimator, online_logger
gae_estimator, online_logger, ContextExchanger, ModelExchanger
from ding.utils import set_pkg_seed
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import main_config, create_config


def main():
logging.getLogger().setLevel(logging.INFO)
cfg = compile_config(main_config, create_cfg=create_config, auto=True)
cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0)
ding_init(cfg)
with task.start(async_mode=False, ctx=OnlineRLContext()):
collector_env = BaseEnvManagerV2(
@@ -32,6 +50,21 @@ def main():
model = VAC(**cfg.policy.model)
policy = PPOPolicy(cfg.policy, model=model)

# Consider the case with multiple processes
if task.router.is_active:
# You can use labels to distinguish between workers with different roles;
# here we use node_id to distinguish them.
if task.router.node_id == 0:
task.add_role(task.role.LEARNER)
elif task.router.node_id == 1:
task.add_role(task.role.EVALUATOR)
else:
task.add_role(task.role.COLLECTOR)

# Sync their context and model between each worker.
task.use(ContextExchanger(skip_n_iter=1))
task.use(ModelExchanger(model))

task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
task.use(StepCollector(cfg, policy.collect_mode, collector_env))
task.use(gae_estimator(cfg, policy.collect_mode))
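
Note on the role block above: it only takes effect when the script is launched through ditask (so that task.router.is_active is True); in a plain single-process run it is skipped and every middleware executes locally. Below is a minimal standalone sketch of the same node_id-to-role mapping; the helper name assign_role_by_node_id is ours for illustration and is not part of DI-engine.

from ding.framework import task


def assign_role_by_node_id() -> None:
    """Illustrative helper (not in DI-engine): mirror the role assignment
    used in ding/example/ppo.py.

    In single-process mode ``task.router.is_active`` is False and this is a
    no-op, so all middleware still run in one process. When launched via
    ``ditask --parallel-workers N --topology mesh``, worker 0 becomes the
    learner, worker 1 the evaluator, and the remaining N-2 workers collectors.
    """
    if not task.router.is_active:
        return
    if task.router.node_id == 0:
        task.add_role(task.role.LEARNER)
    elif task.router.node_id == 1:
        task.add_role(task.role.EVALUATOR)
    else:
        task.add_role(task.role.COLLECTOR)

Under this mapping, launching with e.g. --parallel-workers 6 would simply yield four collectors, matching the "(N-2) collectors" rule stated in the module docstring.
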
6 changes: 6 additions & 0 deletions ding/framework/middleware/functional/advantage_estimator.py
@@ -26,6 +26,8 @@ def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = Non
- policy (:obj:`Policy`): Policy in `policy.collect_mode`, used to get model to calculate value.
- buffer\_ (:obj:`Optional[Buffer]`): The `buffer_` to push the processed data in if `buffer_` is not None.
"""
if task.router.is_active and not task.has_role(task.role.LEARNER):
return task.void()

model = policy.get_attribute('model')
# Unify the shape of obs and action
@@ -104,6 +106,8 @@ def _gae(ctx: "OnlineRLContext"):


def ppof_adv_estimator(policy: Policy) -> Callable:
if task.router.is_active and not task.has_role(task.role.LEARNER):
return task.void()

def _estimator(ctx: "OnlineRLContext"):
data = ttorch_collate(ctx.trajectories, cat_1dim=True)
@@ -118,6 +122,8 @@ def _estimator(ctx: "OnlineRLContext"):


def montecarlo_return_estimator(policy: Policy) -> Callable:
if task.router.is_active and not task.has_role(task.role.LEARNER):
return task.void()

def pg_policy_get_train_sample(data):
assert data[-1]['done'], "PG needs a complete episode"
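
The same two-line guard is now repeated in gae_estimator, ppof_adv_estimator, and montecarlo_return_estimator. If that duplication ever becomes a nuisance, it could be factored into a small decorator; the sketch below is illustrative only, and learner_only is not an existing DI-engine utility.

from functools import wraps
from typing import Callable

from ding.framework import task


def learner_only(factory: Callable) -> Callable:
    """Illustrative decorator (not in DI-engine): reproduce the guard added in
    this commit. When running in parallel and the current worker does not hold
    the LEARNER role, the wrapped middleware factory returns ``task.void()``,
    i.e. a no-op middleware."""

    @wraps(factory)
    def _wrapped(*args, **kwargs):
        if task.router.is_active and not task.has_role(task.role.LEARNER):
            return task.void()
        return factory(*args, **kwargs)

    return _wrapped

With such a decorator, the three estimators above (and the trainer changes below) could drop their inline checks, at the cost of moving the parallel-mode behavior one level away from the function body.
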
4 changes: 4 additions & 0 deletions ding/framework/middleware/functional/trainer.py
@@ -16,6 +16,8 @@ def trainer(cfg: EasyDict, policy: Policy, log_freq: int = 100) -> Callable:
- policy (:obj:`Policy`): The policy to be trained in step-by-step mode.
- log_freq (:obj:`int`): The frequency (iteration) of showing log.
"""
if task.router.is_active and not task.has_role(task.role.LEARNER):
return task.void()

def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
"""
@@ -60,6 +62,8 @@ def multistep_trainer(policy: Policy, log_freq: int = 100) -> Callable:
- policy (:obj:`Policy`): The policy specialized for multi-step training.
- log_freq (:obj:`int`): The frequency (iteration) of showing log.
"""
if task.router.is_active and not task.has_role(task.role.LEARNER):
return task.void()
last_log_iter = -1

def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
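
An alternative to guarding inside each factory would be to gate at registration time, as sketched below; register_training_middleware is a hypothetical helper, not part of this commit or of DI-engine.

from ding.framework import task
from ding.framework.middleware import multistep_trainer
from ding.policy import Policy


def register_training_middleware(learn_policy: Policy) -> None:
    """Hypothetical assembly-time alternative (not what this commit does):
    register the training middleware only on the LEARNER worker, or in a plain
    single-process run where no roles are assigned. Must be called inside
    ``with task.start(...)``."""
    if not task.router.is_active or task.has_role(task.role.LEARNER):
        # log_freq keeps the default shown in multistep_trainer's signature.
        task.use(multistep_trainer(learn_policy, log_freq=100))

Keeping the guard inside trainer and multistep_trainer, as this commit does, lets ding/example/ppo.py register every middleware unconditionally, so the same script works unchanged for both single-process and ditask launches.
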
