-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_multiModal_parallel.py
165 lines (133 loc) · 5.57 KB
/
train_multiModal_parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecVideoRecorder, VecMonitor, VecFrameStack
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import set_random_seed
from gym.wrappers.rescale_action import RescaleAction
# from custom_envs.rampTaperEnv_half import SumoRamp
from gym.spaces import Box
# from custom_envs.bsmMerge import BsmMergeAllRewards as BsmMerge
# from custom_envs.bsmMerge import BsmMerge
# load simple cnn + bsm reward env
# from custom_envs.MultiMerge import MultiMerge
# load cnn + bsm all rewards env
from custom_envs.MultiMergeParallel import MultiallRewards as MultiMerge
import os
import wandb, glob
#from customFeatureExtractor import CustomCombinedExtractor
from wandb.integration.sb3 import WandbCallback
from stable_baselines3.common.monitor import Monitor
#from config_file import sac_multi_config as exp_config
import argparse
# --- Command-line interface --------------------------------------------------
# Usage: python train_multiModal_parallel.py <config> [--render RENDER]
parser = argparse.ArgumentParser(description='train PPO multi model')
parser.add_argument("config", help="Config file")
# NOTE(review): --render is parsed but never read in this file (the
# environments are built with a literal render=0), and because no `type` is
# given any value supplied on the command line arrives as a string.
parser.add_argument("--render", default=0, help="should render")
args = parser.parse_args()

# Resolve the experiment configuration by name. `__import__` with a non-empty
# `fromlist` returns the `config_file` module itself, so `args.config` may name
# either an attribute or a submodule of `config_file`.
module = __import__("config_file", fromlist=[args.config])
try:
    exp_config = getattr(module, args.config)
except AttributeError:
    # Report a bad config name as a normal CLI usage error instead of a raw
    # AttributeError traceback. parser.error() exits the process.
    parser.error(f"unknown config {args.config!r}: not found in config_file")
#print(exp_config.image_shape)

# --- Training budget (environment steps) -------------------------------------
timesteps = 50000
# Duplicate of `sub_timesteps`; not referenced elsewhere in this file. Kept so
# the module namespace is unchanged.
subtimesteps = 10000
sub_timesteps = 10000

# Run metadata passed to wandb.init(); "policy_type" and "total_timesteps" are
# also read back when the PPO model is built and trained.
config = {
    "policy_type": "MultiInputPolicy",
    "total_timesteps": timesteps,
    "env_name": "SumoRamp()",
    "sub_timesteps": sub_timesteps,
}

# --- Output locations --------------------------------------------------------
# '../' is resolved against the current working directory, so outputs land in
# a sibling folder of wherever the script is launched from.
pdir = os.path.abspath('../')
# NOTE: `dir` shadows the builtin of the same name; the name is kept because
# the rest of the script reads it.
dir = os.path.join(pdir, 'SBRampSavedFiles/wandbsavedfiles')
# env = SumoRamp(action_space=action_space, obsspaces=obsspaces, sumoParameters = sumoParameters, weights= weights, isBaseline=False)

# --- Values pulled from the selected experiment config -----------------------
# action_space, obsspaces, weights and sumoParameters are forwarded to the
# MultiMerge constructor in make_env().
# NOTE(review): policy_kwargs and image_shape are loaded here but not
# referenced anywhere else in this file (policy_kwargs is not passed to PPO).
policy_kwargs = exp_config.policy_kwargs
action_space = exp_config.action_space
image_shape = exp_config.image_shape
obsspaces = exp_config.obsspaces
weights = exp_config.weights
sumoParameters = exp_config.sumoParameters

# Bounds handed to RescaleAction in make_env().
min_action = -1
max_action = +1

# Video settings; only used by the currently disabled video-recording code.
video_folder = dir + '/logs/videos/'
video_length = 600
def make_env(env_id, rank, seed=0, monitor_dir=None):
    """
    Build a zero-argument factory that creates one wrapped MultiMerge env.

    The returned callable is meant to be handed to a vectorized environment
    (e.g. SubprocVecEnv), which invokes it to construct the actual env.

    :param env_id: (str) the environment ID (accepted but not read by this function)
    :param rank: (int) index of the subprocess; added to ``seed`` to seed the env
    :param seed: (int) the initial seed for RNG
    :param monitor_dir: (str or None) directory for Monitor log files; when given,
        it is created if missing and the env logs to ``<monitor_dir>/<rank>``.
        When None, the Monitor wrapper is created with ``filename=None``.
    :return: (callable) factory returning the env wrapped in RescaleAction and Monitor
    """
    def _init():
        base_env = MultiMerge(
            action_space=action_space,
            obsspaces=obsspaces,
            sumoParameters=sumoParameters,
            weights=weights,
            isBaseline=False,
            render=0,
        )
        # Offset the seed by rank so every worker gets its own seed.
        base_env.seed(seed + rank)
        scaled_env = RescaleAction(base_env, min_action, max_action)

        if monitor_dir is None:
            log_file = None
        else:
            log_file = os.path.join(monitor_dir, str(rank))
            os.makedirs(monitor_dir, exist_ok=True)
        return Monitor(scaled_env, filename=log_file)

    # Runs once in the calling process when the factory is created, not when
    # the factory itself is invoked.
    set_random_seed(seed)
    return _init
if __name__ == '__main__':
    # Script entry point: start a wandb run, build the parallel environments,
    # train PPO, then save the observation/reward normalization statistics.
    # The __main__ guard is required because SubprocVecEnv starts worker
    # processes.
    run = wandb.init(
        project="Robust-OnRampMerging",
        name=f"ParallelMultiModal_{args.config}",
        dir=dir,
        config=config,
        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
        monitor_gym=True,  # auto-upload the videos of agents playing the game
        save_code=True,  # optional
        magic=True
    )
    env_id = "MultiMerge"
    num_cpu = 2  # Number of processes to use
    # Create the vectorized environment
    # env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)
    env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])
    try:
        # add vstack
        env = VecFrameStack(env, n_stack=4)  # stack 4 frames
        env = VecNormalize(env, norm_obs=True, norm_reward=True, training=True)
        # env = VecVideoRecorder(env, video_folder=f"./videos/{run.id}",
        #                        record_video_trigger=lambda x: x % config["sub_timesteps"] == 0,
        #                        video_length=300)
        # NOTE(review): VecMonitor wraps the reward-normalizing VecNormalize,
        # so the episode returns it records are presumably the normalized
        # rewards rather than the raw ones - confirm this is intended.
        env = VecMonitor(venv=env)
        # eval_env = MultiMerge(action_space=action_space, obsspaces=obsspaces, sumoParameters=sumoParameters, weights=weights,
        #                       isBaseline=False,render=0)
        #code = wandb.Artifact('project-source', type='code')
        #for path in glob.glob('**/*.py', recursive=True):
        #    code.add_file(path)
        #wandb.run.use_artifact(code)
        model = PPO(config["policy_type"], env,
                    verbose=3,
                    gamma=0.95,
                    n_steps=1200,
                    ent_coef=0.0905168,
                    learning_rate=0.005,
                    vf_coef=0.042202,
                    max_grad_norm=0.9,
                    gae_lambda=0.7,
                    n_epochs=10,
                    clip_range=0.3,
                    batch_size=1200,
                    tensorboard_log=f"{dir}")
        #for i in range(100):
        model.learn(
            total_timesteps=int(config["total_timesteps"]),
            #total_timesteps=int(10),
            callback=WandbCallback(
                gradient_save_freq=5,
                model_save_freq=5000,
                model_save_path=f"{dir}/models/{run.id}",
                verbose=2,
            ),
            # eval_env=eval_env,
            # eval_freq=int(config["sub_timesteps"]),
            # n_eval_episodes=10,
            # eval_log_path=f"{dir}/eval/{run.id}",
            #reset_num_timesteps= False,
        )
        #model.save(f"{dir}/models/{run.id}/{i}/model")
        # Persist the normalization statistics next to the saved models so
        # they can be reloaded together with the policy.
        stats_path = os.path.join(f"{dir}/models/{run.id}/", "vec_normalize.pkl")
        env.save(stats_path)
    finally:
        # Always shut down the worker processes, even when training fails;
        # closing the outermost wrapper closes the wrapped SubprocVecEnv.
        env.close()
    # Reached only on success: mark the wandb run as finished. On failure the
    # exception propagates and wandb's own exit handling records the run state.
    run.finish()