Skip to content

Commit

Permalink
Freeze layers for transfer learning.
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel committed Aug 13, 2020
1 parent a6f40a3 commit 459af0f
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
30 changes: 28 additions & 2 deletions training/mozilla_voice_stt_training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,35 @@ def get_tower_results(iterator, optimizer, dropout_rates):
# Retain tower's avg losses
tower_avg_losses.append(avg_loss)

train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

# Filter out layers if we want to freeze some
if FLAGS.freeze_source_layers > 0:
filter_vars = []
if FLAGS.freeze_source_layers <= 5:
filter_vars.append("layer_1")
if FLAGS.freeze_source_layers <= 4:
filter_vars.append("layer_2")
if FLAGS.freeze_source_layers <= 3:
filter_vars.append("layer_3")
if FLAGS.freeze_source_layers <= 2:
filter_vars.append("lstm")
if FLAGS.freeze_source_layers <= 1:
filter_vars.append("layer_5")

new_train_vars = list(train_vars)
for fv in filter_vars:
for tv in train_vars:
if fv in tv.name:
new_train_vars.remove(tv)
train_vars = new_train_vars
msg = "Tower {} - Training only variables: {}"
print(msg.format(i, [v.name for v in train_vars]))
else:
print("Tower {} - Training all layers".format(i))

# Compute gradients for model parameters using tower's mini-batch
gradients = optimizer.compute_gradients(avg_loss)
gradients = optimizer.compute_gradients(avg_loss, var_list=train_vars)

# Retain tower's gradients
tower_gradients.append(gradients)
Expand Down Expand Up @@ -654,7 +681,6 @@ def __call__(self, progress, data, **kwargs):

print('-' * 80)


except KeyboardInterrupt:
pass
log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
Expand Down
15 changes: 15 additions & 0 deletions training/mozilla_voice_stt_training/util/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers):
'tensors. Missing variables: {}'.format(missing_var_names))
sys.exit(1)

if FLAGS.load_frozen_graph:
# After training with "freeze_source_layers" the Adam tensors for the frozen layers aren't
# existing anymore because they were not used
# Therefore we have to initialize them again to continue training on such checkpoints
for v in load_vars:
if v.op.name not in vars_in_ckpt:
if 'Adam' in v.name:
init_vars.add(v)
else:
msg = "Tried to load a frozen checkpoint but there was a missing " \
"variable other than the Adam tensors: {}"
log_error(msg.format(v))
sys.exit(1)
load_vars -= init_vars

if allow_drop_layers and FLAGS.drop_source_layers > 0:
# This transfer learning approach requires supplying
# the layers which we exclude from the source model.
Expand Down
2 changes: 2 additions & 0 deletions training/mozilla_voice_stt_training/util/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def create_flags():
# Transfer Learning

f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
f.DEFINE_integer('freeze_source_layers', 0, 'use same value as above to freeze the other layers')
f.DEFINE_boolean('load_frozen_graph', False, 'Needed to load a graph checkpoint which was trained with "freeze_source_layers" flag. Allows initialization of missing optimization tensors.')

# Exporting

Expand Down

0 comments on commit 459af0f

Please sign in to comment.