-
Notifications
You must be signed in to change notification settings - Fork 4
/
train_lstm.py
49 lines (37 loc) · 1.2 KB
/
train_lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import time
import ray
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from aie.aie_env import OBS_SPACE_AGENT, ACT_SPACE_AGENT
from rl.conf import get_base_ppo_conf
from rl.models.tf.fcnet_lstm import RNNModel
def get_conf():
    """Assemble the PPO trainer config for the single shared LSTM agent policy.

    Starts from the project's base PPO configuration (10 rollout workers)
    and layers on the LSTM-specific training knobs plus a multiagent setup
    in which every agent id is mapped onto one trainable policy, "learned".
    """
    # Single trainable policy spec: (policy_cls=None -> default for the
    # trainer, obs space, act space, per-policy config overrides).
    learned_spec = (
        None,
        OBS_SPACE_AGENT,
        ACT_SPACE_AGENT,
        {
            "model": {
                "custom_model": "my_model",
                "max_seq_len": 50,
            },
        },
    )
    conf = dict(get_base_ppo_conf(num_workers=10))
    conf.update({
        # 60 * 200 * 4 / 3000 = 16 SGD steps of (B=60, L=50, dim)
        "sgd_minibatch_size": 3000,
        "lr": 3e-4,
        "multiagent": {
            "policies_to_train": ["learned"],
            "policies": {"learned": learned_spec},
            # Every agent, regardless of id, uses the one learned policy.
            "policy_mapping_fn": lambda agent_id: 'learned',
        },
    })
    return conf
def run():
    """Register the custom RNN model and train PPO forever.

    Loops indefinitely: each iteration runs one `train()` call, saves a
    checkpoint, and prints the wall-clock seconds the iteration took
    together with the checkpoint path. Terminate externally (Ctrl-C).
    """
    ModelCatalog.register_custom_model("my_model", RNNModel)
    trainer = ppo.PPOTrainer(config=get_conf())
    iter_start = time.monotonic()
    while True:
        trainer.train()
        ckpt_path = trainer.save()
        print(time.monotonic() - iter_start, "checkpoint saved at", ckpt_path)
        iter_start = time.monotonic()
# Script entry point: start the Ray runtime, then run the endless
# train/checkpoint loop defined above.
if __name__ == '__main__':
    ray.init()
    run()