-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_multiModal_final.py
123 lines (103 loc) · 3.91 KB
/
train_multiModal_final.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecVideoRecorder, VecMonitor, VecFrameStack
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import set_random_seed
from gym.wrappers.rescale_action import RescaleAction
from gym.spaces import Box
from custom_envs.MultiMerge import MultiMergeAllRewards as MultiMerge
import os
import wandb, glob
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.monitor import Monitor
import argparse
# --- Command-line interface --------------------------------------------------
# `config` is the name of an attribute of the project-local `config_file`
# module that bundles the experiment settings used below.
parser = argparse.ArgumentParser(description='train PPO multi model')
parser.add_argument("config", help="Config file")
# type=int: without it argparse returns the raw string, so `--render 0` would
# yield the truthy string "0" instead of the integer 0.
# NOTE(review): args.render is not consumed anywhere in this script
# (make_env hard-codes render=0).
parser.add_argument("--render", type=int, default=0, help="should render")
args = parser.parse_args()

# Resolve the experiment config by name. With a non-empty fromlist,
# __import__ returns the `config_file` module itself (and, if `config_file`
# is a package, also imports a submodule with that name).
module = __import__("config_file", fromlist=[args.config])
exp_config = getattr(module, args.config)

timesteps = 500000
# Run metadata passed to wandb.init(config=...); the training entry point also
# reads "policy_type" and "total_timesteps" back out of it.
config = {
    "policy_type": "MultiInputPolicy",
    "total_timesteps": timesteps,
    "env_name": "SumoRamp()",
}

# Root for all run artefacts (wandb files, tensorboard logs, saved models).
# os.path.abspath('../') is resolved against the current working directory,
# so the script must be launched from the expected folder.
# NOTE(review): `dir` shadows the builtin; kept because the rest of the
# script refers to it by this name.
pdir = os.path.abspath('../')
dir = os.path.join(pdir, 'SBRampSavedFiles/wandbsavedfiles')

# Unpack the experiment config into module-level names. make_env() reads the
# spaces/weights/SUMO parameters; the PPO constructor reads policy_kwargs.
policy_kwargs = exp_config.policy_kwargs
action_space = exp_config.action_space
image_shape = exp_config.image_shape
obsspaces = exp_config.obsspaces
weights = exp_config.weights
sumoParameters = exp_config.sumoParameters

# Action range exposed to the agent; make_env() wraps each env in
# RescaleAction(env, min_action, max_action).
min_action = -1
max_action = +1

# NOTE(review): video settings are defined but no VecVideoRecorder is used.
video_folder = dir + '/logs/videos/'
video_length = 600
def make_env(env_id, rank, seed=0, monitor_dir=None):
    """
    Return a factory that builds one MultiMerge environment for a vectorized env.

    The returned zero-argument callable is meant to be handed to
    SubprocVecEnv/DummyVecEnv, which invoke it to construct each environment.

    :param env_id: (str) environment identifier; currently unused, because the
        environment class (MultiMerge) and its settings come from module-level
        config
    :param rank: (int) index of the subprocess; added to ``seed`` so each
        environment gets its own seed
    :param seed: (int) the initial seed for RNG
    :param monitor_dir: (str or None) when given, each environment is wrapped
        in a ``Monitor`` that logs to ``<monitor_dir>/<rank>``
    :return: (callable) function that creates and returns the wrapped env
    """
    def _init():
        # NOTE(review): rendering is hard-coded off here; the --render CLI flag
        # is not forwarded.
        env = MultiMerge(action_space=action_space, obsspaces=obsspaces,
                         sumoParameters=sumoParameters, weights=weights,
                         isBaseline=False, render=0)
        # Per-worker seed offset so parallel environments are not seeded identically.
        env.seed(seed + rank)
        # Expose the module-level [min_action, max_action] range to the agent and
        # rescale it onto the underlying env's action bounds.
        env = RescaleAction(env, min_action, max_action)
        monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None
        if monitor_path is not None:
            os.makedirs(monitor_dir, exist_ok=True)
            # Actually attach the Monitor; previously the path was computed and
            # the directory created but the env was never wrapped.
            env = Monitor(env, filename=monitor_path)
        return env
    set_random_seed(seed)
    return _init
if __name__ == '__main__':
    run = wandb.init(
        project="RMMRL-Training",
        name="MultiModal_NoNoise",
        dir=dir,
        config=config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
        magic=True,
    )
    env_id = "MultiMerge"
    num_cpu = 16  # Number of processes to use

    # Create the vectorized environment: one MultiMerge instance per worker.
    env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])
    # VecMonitor goes directly on the raw envs so the episode returns it reports
    # (ep_rew_mean in tensorboard/wandb) are unnormalized. When it wrapped
    # VecNormalize it recorded the normalized rewards instead. VecMonitor does
    # not alter observations or rewards, so the training data is unchanged.
    env = VecMonitor(venv=env)
    env = VecFrameStack(env, n_stack=4)  # stack 4 frames
    env = VecNormalize(env, norm_obs=True, norm_reward=True, training=True)

    model = PPO(
        config["policy_type"],
        env,
        verbose=3,
        policy_kwargs=policy_kwargs,
        gamma=0.99,
        n_steps=512,  # per environment -> 512 * num_cpu samples per rollout
        learning_rate=0.0001,
        vf_coef=0.042202,
        max_grad_norm=0.9,
        gae_lambda=0.95,
        n_epochs=10,
        clip_range=0.2,
        batch_size=256,
        tensorboard_log=f"{dir}",
    )

    # Checkpoints from WandbCallback and the final VecNormalize statistics share
    # one directory per wandb run.
    model_dir = f"{dir}/models/{run.id}"
    model.learn(
        total_timesteps=int(config["total_timesteps"]),
        callback=WandbCallback(
            gradient_save_freq=5,
            model_save_freq=5000,
            model_save_path=model_dir,
            verbose=2,
        ),
    )

    # The VecNormalize running statistics are needed to reproduce the same
    # observation/reward scaling when the trained model is evaluated later.
    stats_path = os.path.join(model_dir, "vec_normalize.pkl")
    env.save(stats_path)

    # Shut down the SubprocVecEnv worker processes and close the wandb run.
    env.close()
    run.finish()