# train_submission_code.py
# Generated from minerllabs/basalt_competition_submission_template.
from agents.bc import BCAgent
from agents.soft_q import SoftQAgent
import aicrowd_helper
from algorithms.offline import SupervisedLearning
from algorithms.online_imitation import OnlineImitation
from algorithms.sac_iqlearn import IQLearnSAC
from algorithms.sac_curiosity import CuriositySAC
from algorithms.curious_iq import CuriousIQ
from core.datasets import ReplayBuffer, SequenceReplayBuffer
from core.datasets import TrajectoryStepDataset, TrajectorySequenceDataset
from core.environment import start_env
from core.trajectory_generator import TrajectoryGenerator
from modules.termination_critic import TerminationCritic
from utility.config import get_config, parse_args
from utility.parser import Parser
import logging
import os
from pathlib import Path
import time
import coloredlogs
from flatten_dict import flatten
from omegaconf import DictConfig, OmegaConf
from pyvirtualdisplay import Display
import numpy as np
import torch as th
from torch.profiler import profile, record_function, ProfilerActivity, schedule
import wandb
# Install colored log output at DEBUG level for the whole process.
coloredlogs.install(logging.DEBUG)

# You need to ensure that your submission is trained by launching fewer
# than MINERL_TRAINING_MAX_INSTANCES instances.
MINERL_TRAINING_MAX_INSTANCES = int(os.getenv('MINERL_TRAINING_MAX_INSTANCES', 5))
# The dataset is available in the data/ directory from the repository root.
MINERL_DATA_ROOT = os.getenv('MINERL_DATA_ROOT', 'data/')
# You need to ensure that your submission is trained within the allowed
# training time. Read from MINERL_TRAINING_TIMEOUT_MINUTES; the default is
# 4 * 24 * 60 (four days expressed in minutes).
MINERL_TRAINING_TIMEOUT = int(os.getenv('MINERL_TRAINING_TIMEOUT_MINUTES', 4 * 24 * 60))

# Optional: You can view the best-effort status of your instances with the
# help of parser.py. This will give you the current state, such as the number
# of steps completed, instances launched and so on.
# Make sure you keep tabs on these numbers to avoid breaching any limits.
parser = Parser(
    'performance/',
    maximum_instances=MINERL_TRAINING_MAX_INSTANCES,
    raise_on_error=False,
    no_entry_poll_timeout=600,
    # NOTE(review): presumably converts the timeout from minutes to seconds;
    # confirm against Parser's expected unit.
    submission_timeout=MINERL_TRAINING_TIMEOUT * 60,
    initial_poll_timeout=600
)

# Write the resolved data root back to the environment so that the default
# ('data/') is also visible to anything reading MINERL_DATA_ROOT from os.environ.
os.environ["MINERL_DATA_ROOT"] = MINERL_DATA_ROOT
def main(args=None, config=None):
    """
    Run the training phase and save the trained agent.

    This function will be called for the training phase. It should produce
    and save the same files you upload during your submission.

    Args:
        args: Parsed command-line arguments. When ``None``, they are obtained
            from ``parse_args()``.
        config: Experiment configuration. When missing (or empty), it is
            obtained from ``get_config(args)``.

    Raises:
        ValueError: If no expert dataset is configured, or if the configured
            loss function / algorithm combination is not supported.
    """
    aicrowd_helper.training_start()

    logging.basicConfig(level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    if args is None:
        args = parse_args()
    # Falsy check kept on purpose: an empty config is also replaced.
    if not config:
        config = get_config(args)

    environment = config.env.name

    if config.wandb:
        wandb.init(
            project=config.project_name,
            entity="basalt",
            notes="test",
            config=flatten(OmegaConf.to_container(config, resolve=True),
                           reducer='dot'),
        )

    display = None
    env = None
    try:
        # Start Virtual Display
        if args.virtual_display:
            started_display = Display(visible=0, size=(400, 300))
            started_display.start()
            # Only register the display for cleanup once it actually started.
            display = started_display

        # Start env (pure supervised learning does not interact with an env)
        if config.method.algorithm != 'supervised_learning':
            if args.debug_env:
                print('Starting Debug Env')
            else:
                print(f'Starting Env: {environment}')
            env = start_env(config, debug_env=args.debug_env)

        # Models with LSTM layers use the sequence variants of buffer/dataset.
        recurrent = config.model.lstm_layers != 0
        replay_buffer = SequenceReplayBuffer(config) if recurrent \
            else ReplayBuffer(config)

        iter_count = 0
        if config.method.online and config.method.starting_steps > 0:
            # Pre-fill the replay buffer with random trajectories and count
            # those steps towards the iteration budget.
            replay_buffer = TrajectoryGenerator(
                env, None, config, replay_buffer, training=True
            ).random_trajectories(config.method.starting_steps)
            iter_count += config.method.starting_steps

        # initialize dataset, agent, algorithm
        expert_dataset = None
        if config.method.expert_dataset:
            dataset_cls = TrajectorySequenceDataset if recurrent \
                else TrajectoryStepDataset
            expert_dataset = dataset_cls(config, debug_dataset=args.debug_env)
        if expert_dataset is None:
            # Every algorithm below takes the expert dataset as its first
            # argument; fail with a clear message instead of an unbound name.
            raise ValueError(
                'config.method.expert_dataset must be set: all supported '
                'training algorithms require an expert dataset')

        loss_function = config.method.loss_function
        if loss_function == 'iqlearn':
            agent = SoftQAgent(config)
        elif loss_function == 'bc':
            agent = BCAgent(config)
        else:
            raise ValueError(f'Unsupported loss function: {loss_function!r}')

        algorithm = config.method.algorithm
        # if config.method.algorithm == 'curious_IQ':
        #     training_algorithm = CuriousIQ(expert_dataset, config,
        #                                    initial_replay_buffer=replay_buffer,
        #                                    initial_iter_count=iter_count)
        if algorithm == 'sac' and loss_function == 'iqlearn':
            training_algorithm = IQLearnSAC(expert_dataset, config,
                                            initial_replay_buffer=replay_buffer,
                                            initial_iter_count=iter_count)
        elif algorithm == 'online_imitation':
            training_algorithm = OnlineImitation(expert_dataset, agent, config,
                                                 initial_replay_buffer=replay_buffer,
                                                 initial_iter_count=iter_count)
        elif algorithm == 'supervised_learning':
            training_algorithm = SupervisedLearning(expert_dataset, agent, config)
        else:
            raise ValueError(
                f'Unsupported algorithm / loss function combination: '
                f'{algorithm!r} / {loss_function!r}')

        # run algorithm
        if not args.profile:
            agent, replay_buffer = training_algorithm(env)
        else:
            print('Training with profiler')
            profile_dir = f'./logs/{training_algorithm.name}/'
            with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                         on_trace_ready=th.profiler.tensorboard_trace_handler(profile_dir),
                         schedule=schedule(skip_first=32, wait=5,
                                           warmup=1, active=3, repeat=2)) as prof:
                with record_function("model_inference"):
                    agent, replay_buffer = training_algorithm(env, profiler=prof)
            # NOTE(review): the wandb run is initialised based on config.wandb
            # but artifacts are gated on args.wandb - confirm both flags agree.
            if args.wandb:
                profile_art = wandb.Artifact("trace", type="profile")
                for profile_file_path in Path(profile_dir).iterdir():
                    profile_art.add_file(profile_file_path)
                profile_art.save()

        # save model
        if not args.debug_env:
            # Make sure the output directory exists before saving into it.
            os.makedirs('train', exist_ok=True)
            agent_save_path = os.path.join('train', f'{training_algorithm.name}.pth')
            agent.save(agent_save_path)
            if args.wandb:
                model_art = wandb.Artifact("agent", type="model")
                model_art.add_file(agent_save_path)
                model_art.save()

        # Training 100% Completed
        aicrowd_helper.register_progress(1)
    finally:
        # Release the display and the environment even when training fails,
        # and close the env even if stopping the display raises.
        try:
            if display is not None:
                display.stop()
        finally:
            if env is not None:
                env.close()
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()