-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathray_dqn_agent_ddpg.py
101 lines (79 loc) · 3.75 KB
/
ray_dqn_agent_ddpg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import ray
import ray.rllib.agents.ddpg as ddpg
from ray.rllib.agents.ddpg import DDPGTrainer
from ray.tune.logger import pretty_print
import gym
import gym_cityflow
from gym_cityflow.envs.cityflow_env import CityflowGymEnv
from utility import parse_roadnet
import logging
from datetime import datetime
from tqdm import tqdm
import argparse
import json
def env_config(args):
    """Build the environment config dict from CLI args and JSON config files.

    Loads the global config file, then the CityFlow config it points to,
    and parses the roadnet file to derive per-intersection state/action
    sizes for the agent.

    Args:
        args: parsed argparse namespace (needs .config, .num_step,
              .state_time_span, .time_span, .batch_size).

    Returns:
        dict: the environment configuration consumed by CityflowGymEnv.
    """
    # Use context managers so file handles are closed promptly
    # (the original `json.load(open(...))` leaked them until GC).
    with open(args.config) as cfg_file:
        config = json.load(cfg_file)
    config["num_step"] = args.num_step
    # config["replay_data_path"] = "replay"
    with open(config['cityflow_config_file']) as cf_file:
        cityflow_config = json.load(cf_file)
    roadnet_file = cityflow_config['dir'] + cityflow_config['roadnetFile']
    config["lane_phase_info"] = parse_roadnet(roadnet_file)
    config["state_time_span"] = args.state_time_span
    config["time_span"] = args.time_span
    # Agent sizes are derived from the first intersection in the roadnet.
    intersection_id = list(config['lane_phase_info'].keys())[0]
    phase_list = config['lane_phase_info'][intersection_id]['phase']
    logging.info(phase_list)
    # State is the vehicle count per start lane; the current phase is
    # deliberately excluded (see the commented-out variant in history).
    config["state_size"] = len(config['lane_phase_info'][intersection_id]['start_lane'])
    config["action_size"] = len(phase_list)
    config["batch_size"] = args.batch_size
    return config
def agent_config(config_env):
    """Assemble the RLlib DDPG trainer configuration.

    Starts from RLlib's DDPG defaults and overlays CPU-only, single-worker
    settings plus the CityFlow environment and its config.

    Args:
        config_env: environment config dict (as built by env_config).

    Returns:
        dict: a complete trainer config for DDPGTrainer.
    """
    agent_cfg = ddpg.DEFAULT_CONFIG.copy()
    agent_cfg.update({
        "num_gpus": 0,
        "num_workers": 1,
        "env": CityflowGymEnv,
        "env_config": config_env,
    })
    return agent_cfg
# def get_episode_reward(info):
# episode=info
def main():
    """Entry point: parse CLI args, build configs, and train a DDPG agent.

    Trains for --epoch iterations, checkpointing every --save_freq
    iterations and once more at the end.
    """
    ray.init()
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    # parser.add_argument('--scenario', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--config', type=str, default='config/global_config.json', help='config file')
    # BUG FIX: the script always builds a DDPGTrainer, so DDPG must be a
    # valid choice; the original choices (DQN variants only) are kept for
    # backward compatibility. --algo only affects nothing beyond naming.
    parser.add_argument('--algo', type=str, default='DQN', choices=['DQN', 'DDQN', 'DuelDQN', 'DDPG'],
                        help='choose an algorithm')
    parser.add_argument('--inference', action="store_true", help='inference or training')
    # BUG FIX: help text was a copy-paste of another option.
    parser.add_argument('--ckpt', type=str, help='checkpoint path to restore for inference')
    parser.add_argument('--epoch', type=int, default=10, help='number of training epochs')
    parser.add_argument('--num_step', type=int, default=10 ** 3,
                        help='number of timesteps for one episode, and for inference')
    parser.add_argument('--save_freq', type=int, default=100, help='model saving frequency')
    # BUG FIX: help text said 'model saving frequency'.
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--state_time_span', type=int, default=5, help='state interval to receive long term state')
    parser.add_argument('--time_span', type=int, default=30, help='time interval to collect data')
    args = parser.parse_args()

    config_env = env_config(args)
    # ray.tune.register_env('gym_cityflow', lambda env_config:CityflowGymEnv(config_env))
    config_agent = agent_config(config_env)

    trainer = DDPGTrainer(
        env=CityflowGymEnv,
        config=config_agent)

    # BUG FIX: the loop was hardcoded to 1000 iterations and checkpointed
    # every 20, ignoring the parsed --epoch and --save_freq arguments.
    for i in range(args.epoch):
        # Perform one iteration of training the policy.
        result = trainer.train()
        print(pretty_print(result))

        if i % args.save_freq == 0:
            checkpoint = trainer.save()
            print("checkpoint saved at", checkpoint)

    # Always persist the final state of training.
    checkpoint = trainer.save()
    print("checkpoint saved at", checkpoint)
# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()