polish(nyz): polish offpolicy RL multi-gpu DDP training #679

Merged (5 commits, Jul 13, 2023)
Changes from 1 commit
polish(nyz): polish offpolicy RL multi-gpu DDP training
PaParaZz1 committed Jun 27, 2023
commit c10c861b5de010736a56053f417fae72c896c546
4 changes: 2 additions & 2 deletions ding/config/config.py
@@ -13,7 +13,7 @@
from easydict import EasyDict
from copy import deepcopy

from ding.utils import deep_merge_dicts
from ding.utils import deep_merge_dicts, get_rank
from ding.envs import get_env_cls, get_env_manager_cls, BaseEnvManager
from ding.policy import get_policy_cls
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, Coordinator, \
@@ -459,7 +459,7 @@ def compile_config(
cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
if 'exp_name' not in cfg:
cfg.exp_name = 'default_experiment'
if save_cfg:
if save_cfg and get_rank() == 0:
if os.path.exists(cfg.exp_name) and renew_dir:
cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S")
try:
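
Under DDP, the get_rank() == 0 guard above means only the rank-0 process renames an existing experiment directory and serializes the config, so the other ranks never race on the filesystem. A minimal sketch of that pattern, assuming get_rank() falls back to 0 when training is not distributed (the helper below is illustrative, not the actual compile_config body):

import os
import datetime
from ding.utils import get_rank

def save_config_rank0(cfg, renew_dir: bool = True) -> None:
    # Non-zero DDP ranks skip all filesystem work.
    if get_rank() != 0:
        return
    if os.path.exists(cfg.exp_name) and renew_dir:
        # Append a timestamp so the previous experiment directory is kept intact.
        cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S")
    os.makedirs(cfg.exp_name, exist_ok=True)
    # ... dump the formatted config under cfg.exp_name, as compile_config does ...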
33 changes: 17 additions & 16 deletions ding/entry/serial_entry.py
@@ -11,7 +11,7 @@
create_serial_collector, create_serial_evaluator
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.utils import set_pkg_seed, get_rank
from .utils import random_collect


@@ -61,7 +61,7 @@ def serial_pipeline(
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
collector = create_serial_collector(
cfg.policy.collect.collector,
@@ -119,18 +119,19 @@ def serial_pipeline(

# Learner's after_run hook.
learner.call_hook('after_run')
import time
import pickle
import numpy as np
with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f:
eval_value_raw = [d['eval_episode_return'] for d in eval_info]
final_data = {
'stop': stop,
'env_step': collector.envstep,
'train_iter': learner.train_iter,
'eval_value': np.mean(eval_value_raw),
'eval_value_raw': eval_value_raw,
'finish_time': time.ctime(),
}
pickle.dump(final_data, f)
if get_rank() == 0:
import time
import pickle
import numpy as np
with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f:
eval_value_raw = [d['eval_episode_return'] for d in eval_info]
final_data = {
'stop': stop,
'env_step': collector.envstep,
'train_iter': learner.train_iter,
'eval_value': np.mean(eval_value_raw),
'eval_value_raw': eval_value_raw,
'finish_time': time.ctime(),
}
pickle.dump(final_data, f)
return policy
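
With the rank-0 guard, only a single result.pkl is written per experiment. For reference, the saved dictionary can be inspected afterwards with a snippet like the following (hypothetical, not part of the PR; replace 'default_experiment' with the actual cfg.exp_name):

import os
import pickle

with open(os.path.join('default_experiment', 'result.pkl'), 'rb') as f:
    result = pickle.load(f)
# Keys written by serial_pipeline: stop, env_step, train_iter, eval_value, eval_value_raw, finish_time.
print(result['stop'], result['env_step'], result['train_iter'], result['eval_value'])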
9 changes: 6 additions & 3 deletions ding/policy/policy_factory.py
@@ -35,9 +35,12 @@ def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
actions = {}
for env_id in data:
if not isinstance(action_space, list):
action = torch.as_tensor(action_space.sample())
if isinstance(action_space, gym.spaces.MultiDiscrete):
action = [torch.LongTensor([v]) for v in action]
if isinstance(action_space, gym.spaces.Discrete):
action = torch.LongTensor([action_space.sample()])
elif isinstance(action_space, gym.spaces.MultiDiscrete):
action = [torch.LongTensor([v]) for v in action_space.sample()]
else:
action = torch.as_tensor(action_space.sample())
actions[env_id] = {'action': action}
elif 'global_state' in data[env_id].keys():
# for smac
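
The refactored branch makes the random policy's output type explicit per action space: Discrete samples become a 1-element LongTensor, MultiDiscrete samples become a list of 1-element LongTensors, and any other space (e.g. Box) falls back to torch.as_tensor. A condensed sketch of just that conversion, not the full random-policy forward:

import gym
import torch

def sample_random_action(action_space: gym.Space):
    # Mirrors the per-space branching in the diff above.
    if isinstance(action_space, gym.spaces.Discrete):
        return torch.LongTensor([action_space.sample()])  # e.g. tensor([3])
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        return [torch.LongTensor([v]) for v in action_space.sample()]
    else:
        return torch.as_tensor(action_space.sample())  # e.g. Box -> float tensor

sample_random_action(gym.spaces.Discrete(6))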
5 changes: 3 additions & 2 deletions ding/utils/__init__.py
@@ -31,9 +31,10 @@
from .fast_copy import fastcopy
from .bfs_helper import get_vi_sequence

if ding.enable_linklink:
if ding.enable_linklink: # False as default
from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
allreduce, broadcast, DistContext, allreduce_async, synchronize
else:
from .pytorch_ddp_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
allreduce, broadcast, DistContext, allreduce_async, synchronize
allreduce, broadcast, DDPContext, allreduce_async, synchronize, reduce_data, broadcast_object_list, \
to_ddp_config, allreduce_data
44 changes: 41 additions & 3 deletions ding/utils/pytorch_ddp_dist_helper.py
@@ -1,6 +1,7 @@
import os
from typing import Callable, Tuple, List, Any
from typing import Callable, Tuple, List, Any, Union
from easydict import EasyDict

import os
import numpy as np
import torch
import torch.distributed as dist
@@ -30,6 +31,7 @@ def get_world_size() -> int:

broadcast = dist.broadcast
allgather = dist.all_gather
broadcast_object_list = dist.broadcast_object_list


def allreduce(x: torch.Tensor) -> None:
@@ -42,6 +44,35 @@ def allreduce_async(name: str, x: torch.Tensor) -> None:
dist.all_reduce(x, async_op=True)


def reduce_data(x: Union[int, float, torch.Tensor], dst: int) -> Union[int, float, torch.Tensor]:
if np.isscalar(x):
x_tensor = torch.as_tensor([x]).cuda()
dist.reduce(x_tensor, dst)
return x_tensor.item()
elif isinstance(x, torch.Tensor):
dist.reduce(x, dst)
return x
else:
raise TypeError("not supported type: {}".format(type(x)))


def allreduce_data(x: Union[int, float, torch.Tensor], op: str) -> Union[int, float, torch.Tensor]:
assert op in ['sum', 'avg'], op
if np.isscalar(x):
x_tensor = torch.as_tensor([x]).cuda()
dist.all_reduce(x_tensor)
if op == 'avg':
x_tensor.div_(get_world_size())
return x_tensor.item()
elif isinstance(x, torch.Tensor):
dist.all_reduce(x)
if op == 'avg':
x.div_(get_world_size())
return x
else:
raise TypeError("not supported type: {}".format(type(x)))


synchronize = torch.cuda.synchronize

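
allreduce_data sums or averages a Python scalar or a tensor across every rank and returns the result to all of them, while reduce_data sums onto a single destination rank. A minimal usage sketch, assuming the process group is already initialized (e.g. inside DDPContext) and CUDA is available, since scalars are moved to the GPU before the collective:

import torch
from ding.utils import allreduce_data, reduce_data, get_rank

# Every rank contributes its local count and receives the global sum back.
local_samples = 128
global_samples = allreduce_data(local_samples, 'sum')

# Average a per-rank loss tensor in place across ranks.
loss = torch.tensor([0.5]).cuda()
loss = allreduce_data(loss, 'avg')

# Sum env steps onto rank 0 only; the value returned on other ranks should be ignored.
global_envstep = reduce_data(1000, 0)
if get_rank() == 0:
    print(global_samples, loss.item(), global_envstep)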

@@ -119,7 +150,7 @@ def dist_finalize() -> None:
pass


class DistContext:
class DDPContext:

def __init__(self) -> None:
pass
@@ -146,3 +177,10 @@ def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
groups.append(dist.new_group(rank_list[i]))
group_size = world_size // num_groups
return groups[rank // group_size]


def to_ddp_config(cfg: EasyDict) -> EasyDict:
w = get_world_size()
cfg.policy.learn.batch_size = int(np.ceil(cfg.policy.learn.batch_size / w))
cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample) / w)
return cfg
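
to_ddp_config rescales the per-rank workload so that adding GPUs keeps the global batch size and collection budget roughly constant: each rank trains on batch_size / world_size samples and collects n_sample / world_size transitions per iteration (note that, as written above, np.ceil on n_sample is applied before the division, so that value is truncated rather than rounded up when it does not divide evenly). A small numeric sketch of the intended effect with two GPUs, using placeholder values:

from easydict import EasyDict
import numpy as np

cfg = EasyDict(dict(policy=dict(learn=dict(batch_size=64), collect=dict(n_sample=96))))
w = 2  # stand-in for get_world_size()

per_rank_batch = int(np.ceil(cfg.policy.learn.batch_size / w))     # 64 -> 32 per rank
per_rank_n_sample = int(np.ceil(cfg.policy.collect.n_sample / w))  # 96 -> 48 per rank
# Two ranks together still train on 64 samples per step and collect 96 transitions.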
16 changes: 3 additions & 13 deletions ding/worker/collector/interaction_serial_evaluator.py
@@ -2,12 +2,11 @@
from collections import namedtuple
import numpy as np
import torch
import torch.distributed as dist

from ding.envs import BaseEnvManager
from ding.torch_utils import to_tensor, to_ndarray, to_item
from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY
from ding.utils import get_world_size, get_rank
from ding.utils import get_world_size, get_rank, broadcast_object_list
from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor


@@ -65,10 +64,7 @@ def __init__(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name
)
else:
self._logger, _ = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
)
self._tb_logger = None
self._logger, self._tb_logger = None, None # for close elegantly
self.reset(policy, env)

self._timer = EasyTimer()
@@ -199,12 +195,6 @@ def eval(
- stop_flag (:obj:`bool`): Whether this training program can be ended.
- return_info (:obj:`dict`): Current evaluation return information.
'''
if get_world_size() > 1:
# sum up envstep to rank0
envstep_tensor = torch.tensor(envstep).cuda()
dist.reduce(envstep_tensor, dst=0)
envstep = envstep_tensor.item()

# evaluator only work on rank0
stop_flag, return_info = False, []
if get_rank() == 0:
@@ -308,7 +298,7 @@ def eval(

if get_world_size() > 1:
objects = [stop_flag, return_info]
dist.broadcast_object_list(objects, src=0)
broadcast_object_list(objects, src=0)
stop_flag, return_info = objects

return_info = to_item(return_info)
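
After this change the evaluation loop itself runs only on rank 0 (the in-place dist.reduce of envstep is dropped here because the collector now all-reduces its step counters), and the outcome is shared with the other ranks through the exported broadcast_object_list wrapper. A minimal sketch of that broadcast pattern, assuming the process group is already initialized; the rank-0 result below is a placeholder:

from ding.utils import get_rank, get_world_size, broadcast_object_list

stop_flag, return_info = False, []
if get_rank() == 0:
    # ... run the real evaluation on rank 0 only ...
    stop_flag, return_info = True, [{'eval_episode_return': 21.0}]  # placeholder result
if get_world_size() > 1:
    objects = [stop_flag, return_info]
    broadcast_object_list(objects, src=0)  # copies rank 0's objects onto every other rank
    stop_flag, return_info = objects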
56 changes: 39 additions & 17 deletions ding/worker/collector/sample_serial_collector.py
@@ -6,7 +6,8 @@
import torch

from ding.envs import BaseEnvManager
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
broadcast_object_list, allreduce_data
from ding.torch_utils import to_tensor, to_ndarray
from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions

@@ -52,16 +53,25 @@ def __init__(
self._cfg = cfg
self._timer = EasyTimer()
self._end_flag = False
self._rank = get_rank()
self._world_size = get_world_size()

if tb_logger is not None:
if self._rank == 0:
if tb_logger is not None:
self._logger, _ = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
)
else:
self._logger, _ = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
)
self._tb_logger = None

self.reset(policy, env)

def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
@@ -184,8 +194,9 @@ def close(self) -> None:
return
self._end_flag = True
self._env.close()
self._tb_logger.flush()
self._tb_logger.close()
if self._tb_logger:
self._tb_logger.flush()
self._tb_logger.close()

def __del__(self) -> None:
"""
@@ -231,6 +242,8 @@ def collect(
if policy_kwargs is None:
policy_kwargs = {}
collected_sample = 0
collected_step = 0
collected_episode = 0
return_data = []

while collected_sample < n_sample:
@@ -276,7 +289,7 @@ def collect(
transition['collect_iter'] = train_iter
self._traj_buffer[env_id].append(transition)
self._env_info[env_id]['step'] += 1
self._total_envstep_count += 1
collected_step += 1
# prepare data
if timestep.done or len(self._traj_buffer[env_id]) == self._traj_len:
# If policy is r2d2:
@@ -294,7 +307,6 @@ def collect(
transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
train_sample = self._policy.get_train_sample(transitions)
return_data.extend(train_sample)
self._total_train_sample_count += len(train_sample)
self._env_info[env_id]['train_sample'] += len(train_sample)
collected_sample += len(train_sample)
self._traj_buffer[env_id].clear()
@@ -303,7 +315,7 @@ def collect(

# If env is done, record episode info and reset
if timestep.done:
self._total_episode_count += 1
collected_episode += 1
reward = timestep.info['eval_episode_return']
info = {
'reward': reward,
@@ -315,6 +327,17 @@ def collect(
# Env reset is done by env_manager automatically
self._policy.reset([env_id])
self._reset_stat(env_id)
collected_duration = sum([d['time'] for d in self._episode_info])
# reduce data when DDP is enabled
if self._world_size > 1:
collected_sample = allreduce_data(collected_sample, 'sum')
collected_step = allreduce_data(collected_step, 'sum')
collected_episode = allreduce_data(collected_episode, 'sum')
collected_duration = allreduce_data(collected_duration, 'sum')
self._total_envstep_count += collected_step
self._total_episode_count += collected_episode
self._total_duration += collected_duration
self._total_train_sample_count += collected_sample
# log
if record_random_collect: # default is true, but when random collect, record_random_collect is False
self._output_log(train_iter)
@@ -333,19 +356,20 @@ def collect(
def _output_log(self, train_iter: int) -> None:
"""
Overview:
Print the output log information. You can refer to Docs/Best Practice/How to understand\
training generated folders/Serial mode/log/collector for more details.
Print the output log information. You can refer to the docs of `Best Practice` to understand \
the training generated logs and tensorboards.
Arguments:
- train_iter (:obj:`int`): the number of training iteration.
"""
if self._rank != 0:
return
if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
self._last_train_iter = train_iter
episode_count = len(self._episode_info)
envstep_count = sum([d['step'] for d in self._episode_info])
train_sample_count = sum([d['train_sample'] for d in self._episode_info])
duration = sum([d['time'] for d in self._episode_info])
episode_return = [d['reward'] for d in self._episode_info]
self._total_duration += duration
info = {
'episode_count': episode_count,
'envstep_count': envstep_count,
@@ -355,15 +379,13 @@ def _output_log(self, train_iter: int) -> None:
'avg_envstep_per_sec': envstep_count / duration,
'avg_train_sample_per_sec': train_sample_count / duration,
'avg_episode_per_sec': episode_count / duration,
'collect_time': duration,
'reward_mean': np.mean(episode_return),
'reward_std': np.std(episode_return),
'reward_max': np.max(episode_return),
'reward_min': np.min(episode_return),
'total_envstep_count': self._total_envstep_count,
'total_train_sample_count': self._total_train_sample_count,
'total_episode_count': self._total_episode_count,
'total_duration': self._total_duration,
# 'each_reward': episode_return,
}
self._episode_info.clear()
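
The collector now accumulates collected_step, collected_episode, collected_sample and the collection duration locally for each collect() call and, when world_size > 1, sums them across ranks with allreduce_data before folding them into the lifetime totals, so the rank-0 log reports global throughput rather than one GPU's share. A condensed sketch of that reduction step, with placeholder per-rank values:

from ding.utils import get_world_size, allreduce_data

# Placeholder counters gathered by one rank during a single collect() call.
collected_sample, collected_step, collected_episode, collected_duration = 48, 400, 2, 1.7

if get_world_size() > 1:
    # Sum each counter over all DDP ranks so the running totals become global.
    collected_sample = allreduce_data(collected_sample, 'sum')
    collected_step = allreduce_data(collected_step, 'sum')
    collected_episode = allreduce_data(collected_episode, 'sum')
    collected_duration = allreduce_data(collected_duration, 'sum')
# These global values then feed _total_envstep_count, _total_episode_count, etc.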
56 changes: 35 additions & 21 deletions ding/worker/replay_buffer/advanced_buffer.py
@@ -7,7 +7,7 @@

from ding.worker.replay_buffer import IBuffer
from ding.utils import SumSegmentTree, MinSegmentTree, BUFFER_REGISTRY
from ding.utils import LockContext, LockContextType, build_logger
from ding.utils import LockContext, LockContextType, build_logger, get_rank
from ding.utils.autolog import TickTime
from .utils import UsedDataRemover, generate_id, SampledDataAttrMonitor, PeriodicThruputMonitor, ThruputController

@@ -106,6 +106,7 @@ def __init__(
self._instance_name = instance_name
self._end_flag = False
self._cfg = cfg
self._rank = get_rank()
self._replay_buffer_size = self._cfg.replay_buffer_size
self._deepcopy = self._cfg.deepcopy
# ``_data`` is a circular queue to store data (full data or meta data)
@@ -163,16 +164,22 @@ def __init__(

# Monitor & Logger
monitor_cfg = self._cfg.monitor
if tb_logger is not None:
if self._rank == 0:
if tb_logger is not None:
self._logger, _ = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name),
self._instance_name,
)
else:
self._logger, _ = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
)
self._tb_logger = tb_logger
else:
self._logger, self._tb_logger = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name),
self._instance_name,
)
self._tb_logger = None
self._start_time = time.time()
# Sampled data attributes.
self._cur_learner_iter = -1
@@ -183,9 +190,10 @@ def __init__(
)
self._sampled_data_attr_print_freq = monitor_cfg.sampled_data_attr.print_freq
# Periodic thruput.
self._periodic_thruput_monitor = PeriodicThruputMonitor(
self._instance_name, monitor_cfg.periodic_thruput, self._logger, self._tb_logger
)
if self._rank == 0:
self._periodic_thruput_monitor = PeriodicThruputMonitor(
self._instance_name, monitor_cfg.periodic_thruput, self._logger, self._tb_logger
)

# Used data remover
self._enable_track_used_data = self._cfg.enable_track_used_data
@@ -210,9 +218,10 @@ def close(self) -> None:
return
self._end_flag = True
self.clear()
self._periodic_thruput_monitor.close()
self._tb_logger.flush()
self._tb_logger.close()
if self._rank == 0:
self._periodic_thruput_monitor.close()
self._tb_logger.flush()
self._tb_logger.close()
if self._enable_track_used_data:
self._used_data_remover.close()

@@ -374,7 +383,8 @@ def _append(self, ori_data: Any, cur_collector_envstep: int = -1) -> None:
self._set_weight(data)
self._data[self._tail] = data
self._valid_count += 1
self._periodic_thruput_monitor.valid_count = self._valid_count
if self._rank == 0:
self._periodic_thruput_monitor.valid_count = self._valid_count
self._tail = (self._tail + 1) % self._replay_buffer_size
self._next_unique_id += 1
self._monitor_update_of_push(1, cur_collector_envstep)
@@ -435,7 +445,8 @@ def _extend(self, ori_data: List[Any], cur_collector_envstep: int = -1) -> None:
data_start = 0
valid_data_start += L
self._valid_count += len(valid_data)
self._periodic_thruput_monitor.valid_count = self._valid_count
if self._rank == 0:
self._periodic_thruput_monitor.valid_count = self._valid_count
# Update ``tail`` and ``next_unique_id`` after the whole list is pushed into buffer.
self._tail = (self._tail + length) % self._replay_buffer_size
self._next_unique_id += length
@@ -568,8 +579,9 @@ def _remove(self, idx: int, use_too_many_times: bool = False) -> None:
if self._enable_track_used_data:
self._used_data_remover.add_used_data(self._data[idx])
self._valid_count -= 1
self._periodic_thruput_monitor.valid_count = self._valid_count
self._periodic_thruput_monitor.remove_data_count += 1
if self._rank == 0:
self._periodic_thruput_monitor.valid_count = self._valid_count
self._periodic_thruput_monitor.remove_data_count += 1
self._data[idx] = None
self._sum_tree[idx] = self._sum_tree.neutral_element
self._min_tree[idx] = self._min_tree.neutral_element
@@ -624,7 +636,8 @@ def _monitor_update_of_push(self, add_count: int, cur_collector_envstep: int = -
- add_count (:obj:`int`): How many datas are added into buffer.
- cur_collector_envstep (:obj:`int`): Collector envstep, passed in by collector.
"""
self._periodic_thruput_monitor.push_data_count += add_count
if self._rank == 0:
self._periodic_thruput_monitor.push_data_count += add_count
if self._use_thruput_controller:
self._thruput_controller.history_push_count += add_count
self._cur_collector_envstep = cur_collector_envstep
@@ -639,7 +652,8 @@ def _monitor_update_of_sample(self, sample_data: list, cur_learner_iter: int) ->
e.g. use, priority, staleness, etc.
- cur_learner_iter (:obj:`int`): Learner iteration, passed in by learner.
"""
self._periodic_thruput_monitor.sample_data_count += len(sample_data)
if self._rank == 0:
self._periodic_thruput_monitor.sample_data_count += len(sample_data)
if self._use_thruput_controller:
self._thruput_controller.history_sample_count += len(sample_data)
self._cur_learner_iter = cur_learner_iter
@@ -668,7 +682,7 @@ def _monitor_update_of_sample(self, sample_data: list, cur_learner_iter: int) ->
'staleness_max': self._sampled_data_attr_monitor.max['staleness'](),
'beta': self._beta,
}
if self._sampled_data_attr_print_count % self._sampled_data_attr_print_freq == 0:
if self._sampled_data_attr_print_count % self._sampled_data_attr_print_freq == 0 and self._rank == 0:
self._logger.info("=== Sample data {} Times ===".format(self._sampled_data_attr_print_count))
self._logger.info(self._logger.get_tabulate_vars_hor(out_dict))
for k, v in out_dict.items():
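
The replay buffer applies the same rank-aware pattern: rank 0 builds the text logger, the TensorBoard writer and the PeriodicThruputMonitor, while the other ranks keep a text logger without TensorBoard and every monitor/TensorBoard update sits behind a rank check. A compact sketch of the pattern, with a placeholder log path and scalar tag:

from ding.utils import build_logger, get_rank

rank = get_rank()
if rank == 0:
    # Rank 0 owns both the text logger and the TensorBoard writer.
    logger, tb_logger = build_logger('./my_exp/log/buffer', 'buffer')
else:
    # Other ranks log text only; TensorBoard writes are skipped entirely.
    logger, _ = build_logger('./my_exp/log/buffer', 'buffer', need_tb=False)
    tb_logger = None

if rank == 0:
    tb_logger.add_scalar('buffer/valid_count', 0, 0)  # placeholder value and step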
100 changes: 0 additions & 100 deletions dizoo/atari/config/parallel/qbert_dqn_config.py

This file was deleted.

117 changes: 0 additions & 117 deletions dizoo/atari/config/parallel/qbert_dqn_config_k8s.py

This file was deleted.

spaceinvaders_dqn_config.py (SpaceInvaders DQN serial config)
@@ -7,18 +7,14 @@
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
#'ALE/SpaceInvaders-v5' is available. But special setting is needed after gym make.
frame_stack=4,
manager=dict(shared_memory=False, ),
# The path to save the game replay
replay_path='./spaceinvaders_dqn_seed0/video',
),
policy=dict(
cuda=True,
priority=False,
load_path="./spaceinvaders_dqn_seed0/ckpt/ckpt_best.pth.tar",
random_collect_size=5000,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
@@ -61,4 +57,4 @@
if __name__ == '__main__':
# or you can enter ding -m serial -c spaceinvaders_dqn_config.py -s 0
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0)
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7))
SpaceInvaders DQN multi-GPU DDP config
@@ -7,16 +7,14 @@
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
#'ALE/SpaceInvaders-v5' is available. But special setting is needed after gym make.
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=True,
multi_gpu=True,
priority=False,
random_collect_size=5000,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
@@ -58,6 +56,7 @@

if __name__ == '__main__':
from ding.entry import serial_pipeline
from ding.utils import DistContext
with DistContext():
serial_pipeline((main_config, create_config), seed=0)
from ding.utils import DDPContext, to_ddp_config
with DDPContext():
main_config = to_ddp_config(main_config)
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7))
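
The multi-GPU entry pattern is therefore: open a DDPContext, rescale the config with to_ddp_config, then call serial_pipeline as usual, with policy.multi_gpu=True set in the config as above. A hedged sketch of the same __main__ block, assuming the script is started with a multi-process launcher (for example torchrun --nproc_per_node=<num_gpus> <config_file>.py) and that the DDP helper picks up the rank/world-size environment it needs:

if __name__ == '__main__':
    from ding.entry import serial_pipeline
    from ding.utils import DDPContext, to_ddp_config, get_rank, get_world_size
    with DDPContext():
        # Divide batch_size / n_sample by the world size so the global workload is unchanged.
        main_config = to_ddp_config(main_config)  # main_config / create_config are the EasyDicts defined earlier in the config file
        if get_rank() == 0:
            # Optional sanity check of the per-rank workload before training starts.
            print('world_size:', get_world_size(), 'per-rank batch_size:', main_config.policy.learn.batch_size)
        serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7))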
SpaceInvaders DQN DataParallel config
@@ -7,11 +7,9 @@
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=10000000000,
env_id='SpaceInvadersNoFrameskip-v4',
#'ALE/SpaceInvaders-v5' is available. But special setting is needed after gym make.
frame_stack=4,
manager=dict(shared_memory=False, )
),
policy=dict(
cuda=True,
@@ -60,4 +58,4 @@
from ding.model.template.q_learning import DQN
from ding.torch_utils import DataParallel
model = DataParallel(DQN(obs_shape=[4, 84, 84], action_shape=6))
serial_pipeline((main_config, create_config), seed=0, model=model)
serial_pipeline((main_config, create_config), seed=0, model=model, max_env_step=int(1e7))
spaceinvaders_qrdqn_config.py (SpaceInvaders QRDQN config)
@@ -59,4 +59,4 @@
if __name__ == '__main__':
# or you can enter ding -m serial -c spaceinvaders_qrdqn_config.py -s 0
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0)
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7))