make save_checkpoints_secs work again (#521)
The functionality was broken during the adoption of the TPU-oriented trainer_lib.py
in place of the original trainer_utils.py.
Currently, the default is to save checkpoints every 2000 steps,
whereas in previous T2T versions the default was every 10 minutes.

martinpopel authored and rsepassi committed Jan 23, 2018
1 parent afba9dc commit f0e638e
Showing 3 changed files with 9 additions and 4 deletions.
7 changes: 5 additions & 2 deletions tensor2tensor/bin/t2t_trainer.py
@@ -108,14 +108,17 @@ def create_experiment_fn():
 
 
 def create_run_config(hp):
+  save_ckpt_steps = max(FLAGS.iterations_per_loop, FLAGS.local_eval_frequency)
+  if FLAGS.save_checkpoints_secs:
+    save_ckpt_steps = None
   return trainer_lib.create_run_config(
       model_dir=os.path.expanduser(FLAGS.output_dir),
       master=FLAGS.master,
       iterations_per_loop=FLAGS.iterations_per_loop,
       num_shards=FLAGS.tpu_num_shards,
       log_device_placement=FLAGS.log_device_placement,
-      save_checkpoints_steps=max(FLAGS.iterations_per_loop,
-                                 FLAGS.local_eval_frequency),
+      save_checkpoints_steps=save_ckpt_steps,
+      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
       keep_checkpoint_max=FLAGS.keep_checkpoint_max,
       keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
      num_gpus=FLAGS.worker_gpu,
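
As a minimal sketch of the logic this hunk introduces (choose_checkpoint_schedule
is a hypothetical helper, not part of the commit, and flags stands in for
tensor2tensor's FLAGS object): step-based and time-based checkpointing are
alternatives in TensorFlow's run configuration, so the step count is cleared
whenever a seconds interval is requested.

    def choose_checkpoint_schedule(flags):
      # Default: checkpoint by step count, mirroring t2t_trainer.py above.
      save_ckpt_steps = max(flags.iterations_per_loop,
                            flags.local_eval_frequency)
      # A nonzero --save_checkpoints_secs switches to time-based saving, so
      # the step-based schedule is disabled to avoid conflicting settings.
      if flags.save_checkpoints_secs:
        save_ckpt_steps = None
      return save_ckpt_steps, flags.save_checkpoints_secs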
4 changes: 2 additions & 2 deletions tensor2tensor/utils/flags.py
@@ -81,8 +81,8 @@
                      "The default value 10,000 hours effectively disables it.")
 flags.DEFINE_integer("save_checkpoints_secs", 0,
                      "Save checkpoints every this many seconds. "
-                     "Default=0 means let tensorflow.contrib.learn.python.learn"
-                     " decide, which is currently set to 600 = 10 minutes.")
+                     "Default=0 means save checkpoints each x steps where x "
+                     "depends on iterations_per_loop and local_eval_frequency.")
 flags.DEFINE_bool("log_device_placement", False,
                   "Whether to log device placement.")
 
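
To make the new flag help text concrete, here is a small sketch of how the
default works out (the values below are illustrative stand-ins for the
--iterations_per_loop and --local_eval_frequency flags, not necessarily the
project's defaults):

    iterations_per_loop = 1000    # illustrative flag value
    local_eval_frequency = 2000   # illustrative flag value
    # With --save_checkpoints_secs left at its default of 0, checkpoints
    # are saved by step count:
    save_ckpt_steps = max(iterations_per_loop, local_eval_frequency)
    print(save_ckpt_steps)  # 2000, the "every 2000 steps" default noted above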
2 changes: 2 additions & 0 deletions tensor2tensor/utils/trainer_lib.py
@@ -87,6 +87,7 @@ def create_run_config(master="",
                       num_shards=8,
                       log_device_placement=False,
                       save_checkpoints_steps=1000,
+                      save_checkpoints_secs=0,
                       keep_checkpoint_max=20,
                       keep_checkpoint_every_n_hours=10000,
                       num_gpus=1,
@@ -121,6 +122,7 @@ def create_run_config(master="",
       "session_config": session_config,
       "save_summary_steps": 100,
       "save_checkpoints_steps": save_checkpoints_steps,
+      "save_checkpoints_secs": save_checkpoints_secs,
       "keep_checkpoint_max": keep_checkpoint_max,
       "keep_checkpoint_every_n_hours": keep_checkpoint_every_n_hours,
       "tf_random_seed": random_seed,
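
With the parameter threaded through, a hypothetical call that restores the
earlier ten-minute behavior might look like this (the model_dir value is an
example; the keyword names come from the diffs above, and the remaining
parameters keep their defaults):

    run_config = trainer_lib.create_run_config(
        model_dir="/tmp/t2t_train",   # example output directory
        save_checkpoints_steps=None,  # disable step-based checkpointing
        save_checkpoints_secs=600)    # save every 10 minutes instead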
