Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add 210 edgestarts for backwards compatibility #985

Merged
merged 12 commits on Jul 8, 2020
Merged
37 changes: 4 additions & 33 deletions examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,45 +213,17 @@ def train_rllib(submodule, flags):
run_experiments({flow_params["exp_tag"]: exp_config})


def train_h_baselines(flow_params, args, multiagent):
def train_h_baselines(env_name, args, multiagent):
"""Train policies using SAC and TD3 with h-baselines."""
from hbaselines.algorithms import OffPolicyRLAlgorithm
from hbaselines.utils.train import parse_options, get_hyperparameters
from hbaselines.envs.mixed_autonomy import FlowEnv

flow_params = deepcopy(flow_params)

# Get the command-line arguments that are relevant here
args = parse_options(description="", example_usage="", args=args)

# the base directory that the logged data will be stored in
base_dir = "training_data"

# Create the training environment.
env = FlowEnv(
flow_params,
multiagent=multiagent,
shared=args.shared,
maddpg=args.maddpg,
render=args.render,
version=0
)

# Create the evaluation environment.
if args.evaluate:
eval_flow_params = deepcopy(flow_params)
eval_flow_params['env'].evaluate = True
eval_env = FlowEnv(
eval_flow_params,
multiagent=multiagent,
shared=args.shared,
maddpg=args.maddpg,
render=args.render_eval,
version=1
)
else:
eval_env = None

for i in range(args.n_training):
# value of the next seed
seed = args.seed + i
Expand Down Expand Up @@ -299,8 +271,8 @@ def train_h_baselines(flow_params, args, multiagent):
# Create the algorithm object.
alg = OffPolicyRLAlgorithm(
policy=policy,
env=env,
eval_env=eval_env,
env="flow:{}".format(env_name),
eval_env="flow:{}".format(env_name) if args.evaluate else None,
**hp
)

Expand Down Expand Up @@ -393,8 +365,7 @@ def main(args):
elif flags.rl_trainer.lower() == "stable-baselines":
train_stable_baselines(submodule, flags)
elif flags.rl_trainer.lower() == "h-baselines":
flow_params = submodule.flow_params
train_h_baselines(flow_params, args, multiagent)
train_h_baselines(flags.exp_config, args, multiagent)
else:
raise ValueError("rl_trainer should be either 'rllib', 'h-baselines', "
"or 'stable-baselines'.")
Expand Down
16 changes: 16 additions & 0 deletions flow/visualize/time_space_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,22 @@ def _get_abs_pos(df, params):
}
elif params['network'] == HighwayNetwork:
return df['x']
elif params['network'] == I210SubNetwork:
edgestarts = {
'119257914': -5.0999999999995795,
'119257908#0': 56.49000000018306,
':300944379_0': 56.18000000000016,
':300944436_0': 753.4599999999871,
'119257908#1-AddedOnRampEdge': 756.3299999991157,
':119257908#1-AddedOnRampNode_0': 853.530000000022,
'119257908#1': 856.7699999997207,
':119257908#1-AddedOffRampNode_0': 1096.4499999999707,
'119257908#1-AddedOffRampEdge': 1099.6899999995558,
':1686591010_1': 1198.1899999999541,
'119257908#2': 1203.6499999994803,
':1842086610_1': 1780.2599999999056,
'119257908#3': 1784.7899999996537,
}
else:
edgestarts = defaultdict(float)

Expand Down
10 changes: 5 additions & 5 deletions tests/fast_tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,22 +229,22 @@ class TestHBaselineExamples(unittest.TestCase):
confirming that it runs.
"""
@staticmethod
def run_exp(flow_params, multiagent):
def run_exp(env_name, multiagent):
train_h_baselines(
flow_params=flow_params,
env_name=env_name,
args=[
flow_params["env_name"].__name__,
env_name,
"--initial_exploration_steps", "1",
"--total_steps", "10"
],
multiagent=multiagent,
)

def test_singleagent_ring(self):
self.run_exp(singleagent_ring.copy(), multiagent=False)
self.run_exp("singleagent_ring", multiagent=False)

def test_multiagent_ring(self):
self.run_exp(multiagent_ring.copy(), multiagent=True)
self.run_exp("multiagent_ring", multiagent=True)


class TestRllibExamples(unittest.TestCase):
Expand Down