# main.py — PPO training script for an AirSim gym navigation environment.
# (Web-scrape artifacts — page chrome and a column of line numbers — removed.)
import gym
import time
import yaml
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecTransposeImage
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import EvalCallback
# Load training-environment configuration from YAML.
with open('scripts/config.yml', 'r') as f:
    env_config = yaml.safe_load(f)

# Wrap the AirSim gym env in a Monitor (episode stats) inside a
# single-process DummyVecEnv, as stable-baselines3 expects a VecEnv.
env = DummyVecEnv([lambda: Monitor(
    gym.make(
        "scripts:airsim-env-v0",
        ip_address="127.0.0.1",
        image_shape=(50, 50, 3),
        env_config=env_config["TrainEnv"],
    )
)])

# Transpose observations from channel-last (H, W, C) to channel-first
# (C, H, W), the layout required by SB3's CnnPolicy.
env = VecTransposeImage(env)

# Initialize PPO with a CNN policy; fixed seed for reproducibility.
model = PPO(
    'CnnPolicy',
    env,
    verbose=1,
    seed=42,
    device="cuda",
    tensorboard_log="./tb_logs/",
)

# Periodically evaluate and save the best-performing model so far.
# NOTE(review): evaluation runs on the training env instance itself;
# ideally a separate eval env would be used — confirm this is intentional.
eval_callback = EvalCallback(
    env,
    callback_on_new_best=None,
    n_eval_episodes=4,
    best_model_save_path=".",
    log_path=".",
    eval_freq=500,
)

# Unique TensorBoard run name per launch.
log_name = "ppo_run_" + str(time.time())

# Train. The callback list is passed directly; the original routed it
# through a redundant intermediate kwargs dict.
model.learn(
    total_timesteps=150000,
    tb_log_name=log_name,
    callback=[eval_callback],
)

# Persist final policy weights (the best intermediate checkpoint is
# already saved by EvalCallback to best_model_save_path).
model.save("ppo_navigation_policy")