Skip to content

Commit

Permalink
merge main into callbacks
Browse files (browse the repository at this point in the history)
  • Loading branch information
mattdeitke committed Apr 28, 2022
2 parents: 3a4f2a5 + 474fb84; commit: 030d2ce
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
19 changes: 16 additions & 3 deletions allenact/algorithms/onpolicy_sync/engine.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1309,7 +1309,7 @@ def _save_checkpoint_then_send_checkpoint_for_validation_and_update_last_save_co
self.checkpoints_queue.put(("eval", model_path))
self.last_save = self.training_pipeline.total_steps

def run_pipeline(self):
def run_pipeline(self, valid_on_initial_weights: bool = False):
cur_stage_training_settings = (
self.training_pipeline.current_stage.training_settings
)
Expand All @@ -1333,6 +1333,16 @@ def run_pipeline(self):
)
already_saved_checkpoint = False

if (
valid_on_initial_weights
and should_save_checkpoints
and self.checkpoints_queue is not None
):
if self.worker_id == self.first_local_worker_id:
model_path = self.checkpoint_save()
if self.checkpoints_queue is not None:
self.checkpoints_queue.put(("eval", model_path))

while True:
pipeline_stage_changed = self.training_pipeline.before_rollout(
train_metrics=self._last_aggregated_train_task_metrics
Expand Down Expand Up @@ -1569,7 +1579,10 @@ def run_pipeline(self):
)

def train(
self, checkpoint_file_name: Optional[str] = None, restart_pipeline: bool = False
self,
checkpoint_file_name: Optional[str] = None,
restart_pipeline: bool = False,
valid_on_initial_weights: bool = False,
):
assert (
self.mode == TRAIN_MODE_STR
Expand All @@ -1581,7 +1594,7 @@ def train(
if checkpoint_file_name is not None:
self.checkpoint_load(checkpoint_file_name, restart_pipeline)

self.run_pipeline()
self.run_pipeline(valid_on_initial_weights=valid_on_initial_weights)

training_completed_successfully = True
except KeyboardInterrupt:
Expand Down
7 changes: 6 additions & 1 deletion allenact/algorithms/onpolicy_sync/runner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -355,6 +355,7 @@ def train_loop(
id: int = 0,
checkpoint: Optional[str] = None,
restart_pipeline: bool = False,
valid_on_initial_weights: bool = False,
*engine_args,
**engine_kwargs,
):
Expand All @@ -372,7 +373,9 @@ def train_loop(
if trainer is not None:
OnPolicyRunner.init_process("Train", id, to_close_on_termination=trainer)
trainer.train(
checkpoint_file_name=checkpoint, restart_pipeline=restart_pipeline
checkpoint_file_name=checkpoint,
restart_pipeline=restart_pipeline,
valid_on_initial_weights=valid_on_initial_weights,
)

@staticmethod
Expand Down Expand Up @@ -446,6 +449,7 @@ def start_train(
max_sampler_processes_per_worker: Optional[int] = None,
save_ckpt_after_every_pipeline_stage: bool = True,
collect_valid_results: bool = False,
valid_on_initial_weights: bool = False,
):
self._initialize_start_train_or_start_test()

Expand Down Expand Up @@ -497,6 +501,7 @@ def start_train(
else model_hash,
first_local_worker_id=worker_ids[0],
distributed_preemption_threshold=self.distributed_preemption_threshold,
valid_on_initial_weights=valid_on_initial_weights,
)
train: BaseProcess = self.mp_ctx.Process(
target=self.train_loop, kwargs=training_kwargs,
Expand Down
10 changes: 10 additions & 0 deletions allenact/main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -220,6 +220,15 @@ def get_argument_parser():
)
parser.set_defaults(collect_valid_results=False)

parser.add_argument(
"--valid_on_initial_weights",
dest="valid_on_initial_weights",
action="store_true",
required=False,
help="enables running validation on the model with initial weights",
)
parser.set_defaults(valid_on_initial_weights=False)

parser.add_argument(
"--test_expert",
dest="test_expert",
Expand Down Expand Up @@ -454,6 +463,7 @@ def main():
restart_pipeline=args.restart_pipeline,
max_sampler_processes_per_worker=args.max_sampler_processes_per_worker,
collect_valid_results=args.collect_valid_results,
valid_on_initial_weights=args.valid_on_initial_weights,
)
else:
OnPolicyRunner(
Expand Down

0 comments on commit 030d2ce

Please sign in to comment.