Skip to content

Commit

Permalink
Valid on initial model weights
Browse files Browse the repository at this point in the history
  • Loading branch information
jordis-ai2 committed Apr 26, 2022
1 parent 456a747 commit d9d5992
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
19 changes: 16 additions & 3 deletions allenact/algorithms/onpolicy_sync/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ def _save_checkpoint_then_send_checkpoint_for_validation_and_update_last_save_co
self.checkpoints_queue.put(("eval", model_path))
self.last_save = self.training_pipeline.total_steps

def run_pipeline(self):
def run_pipeline(self, valid_on_initial_weights: bool = False):
cur_stage_training_settings = (
self.training_pipeline.current_stage.training_settings
)
Expand All @@ -1336,6 +1336,16 @@ def run_pipeline(self):
)
already_saved_checkpoint = False

# Optionally validate the weights the model has when the pipeline starts,
# before entering the training loop below: save a checkpoint and queue it
# for evaluation.
if (
    valid_on_initial_weights
    and should_save_checkpoints
    and self.checkpoints_queue is not None
):
    # Only the first local worker writes the checkpoint. The queue is
    # already known to be non-None from the enclosing condition, so no
    # second check is needed before enqueueing.
    if self.worker_id == self.first_local_worker_id:
        model_path = self.checkpoint_save()
        self.checkpoints_queue.put(("eval", model_path))

while True:
pipeline_stage_changed = self.training_pipeline.before_rollout(
train_metrics=self._last_aggregated_train_task_metrics
Expand Down Expand Up @@ -1572,7 +1582,10 @@ def run_pipeline(self):
)

def train(
self, checkpoint_file_name: Optional[str] = None, restart_pipeline: bool = False
self,
checkpoint_file_name: Optional[str] = None,
restart_pipeline: bool = False,
valid_on_initial_weights: bool = False,
):
assert (
self.mode == TRAIN_MODE_STR
Expand All @@ -1584,7 +1597,7 @@ def train(
if checkpoint_file_name is not None:
self.checkpoint_load(checkpoint_file_name, restart_pipeline)

self.run_pipeline()
self.run_pipeline(valid_on_initial_weights=valid_on_initial_weights)

training_completed_successfully = True
except KeyboardInterrupt:
Expand Down
7 changes: 6 additions & 1 deletion allenact/algorithms/onpolicy_sync/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def train_loop(
id: int = 0,
checkpoint: Optional[str] = None,
restart_pipeline: bool = False,
valid_on_initial_weights: bool = False,
*engine_args,
**engine_kwargs,
):
Expand All @@ -333,7 +334,9 @@ def train_loop(
if trainer is not None:
OnPolicyRunner.init_process("Train", id, to_close_on_termination=trainer)
trainer.train(
checkpoint_file_name=checkpoint, restart_pipeline=restart_pipeline
checkpoint_file_name=checkpoint,
restart_pipeline=restart_pipeline,
valid_on_initial_weights=valid_on_initial_weights,
)

@staticmethod
Expand Down Expand Up @@ -407,6 +410,7 @@ def start_train(
max_sampler_processes_per_worker: Optional[int] = None,
save_ckpt_after_every_pipeline_stage: bool = True,
collect_valid_results: bool = False,
valid_on_initial_weights: bool = False,
):
self._initialize_start_train_or_start_test()

Expand Down Expand Up @@ -457,6 +461,7 @@ def start_train(
if model_hash is None
else model_hash,
first_local_worker_id=worker_ids[0],
valid_on_initial_weights=valid_on_initial_weights,
)
train: BaseProcess = self.mp_ctx.Process(
target=self.train_loop, kwargs=training_kwargs,
Expand Down
10 changes: 10 additions & 0 deletions allenact/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ def get_argument_parser():
)
parser.set_defaults(collect_valid_results=False)

# Flag that requests a validation pass on the model's starting weights.
parser.add_argument(
    "--valid_on_initial_weights",
    dest="valid_on_initial_weights",
    action="store_true",
    required=False,
    help="enables running validation on the model with initial weights",
)
# Set the default for THIS flag (the original line repeated the previous
# flag's `collect_valid_results` default, a copy-paste slip).
parser.set_defaults(valid_on_initial_weights=False)

parser.add_argument(
"--test_expert",
dest="test_expert",
Expand Down Expand Up @@ -443,6 +452,7 @@ def main():
restart_pipeline=args.restart_pipeline,
max_sampler_processes_per_worker=args.max_sampler_processes_per_worker,
collect_valid_results=args.collect_valid_results,
valid_on_initial_weights=args.valid_on_initial_weights,
)
else:
OnPolicyRunner(
Expand Down

0 comments on commit d9d5992

Please sign in to comment.